Compare commits
120 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c8a7617536 | |||
| ce642e92da | |||
| 6a509ba862 | |||
| 6d5beefe29 | |||
| 06beecafc5 | |||
| daf0a23958 | |||
| 38ced7ee59 | |||
| 23c98025b3 | |||
| 8cd7426e56 | |||
| fbce7aeb32 | |||
| 35fada4169 | |||
| fbe2fe5578 | |||
| c86511586f | |||
| 60892c55a4 | |||
| 8fe5a14d9b | |||
| 58431f102c | |||
| 4a9ab650aa | |||
| 0ac1d5b482 | |||
| 7567adfc45 | |||
| 3da98e7ee3 | |||
| b3b04fefde | |||
| 0e3f2713c2 | |||
| a7e9f85e21 | |||
| 9ce89e2efa | |||
| aa5f5d41d6 | |||
| bd96a084d3 | |||
| b863bdd6ca | |||
| f00a995753 | |||
| e8312e7ca9 | |||
| 7986834572 | |||
| edd7880418 | |||
| b4be42282d | |||
| a4f9c3cbc3 | |||
| 4b60f4b602 | |||
| 6cef71de3a | |||
| 026507c06c | |||
| 448c72a230 | |||
| f108ad8888 | |||
| e30d3bf544 | |||
| 6ab62c7431 | |||
| f59df3bb8b | |||
| a00c73a5e1 | |||
| 0434db9a99 | |||
| aff574fb29 | |||
| 79ea8eb258 | |||
| e7f3a73786 | |||
| 7a4a126db8 | |||
| d143851309 | |||
| 9ad1470d48 | |||
| bf99ab2f55 | |||
| ee842839ef | |||
| 96795afc72 | |||
| 12650e1393 | |||
| addaad013c | |||
| 485f8d1758 | |||
| cff0fd6260 | |||
| 8ddb20bfb8 | |||
| e5089d702b | |||
| 2c3e4eafa8 | |||
| c7020df2cf | |||
| 4bed3e306e | |||
| 00a3bc9d6c | |||
| ccb35acd81 | |||
| 00cae4e857 | |||
| b3fb4188f5 | |||
| 71df1581f7 | |||
| d046cf7d35 | |||
| 68a5185c86 | |||
| 6e2fe26bfd | |||
| 77b5fa59c5 | |||
| a226920b52 | |||
| 7007f72409 | |||
| a6804de4a2 | |||
| 7f897a9fc4 | |||
| 0966663d2a | |||
| fb78f4f12d | |||
| 2220af6940 | |||
| 7a34832d52 | |||
| e973de64f9 | |||
| db94ca882d | |||
| 6985906a2e | |||
| 54f410db6c | |||
| c12a05b9c1 | |||
| 2e0f5c86cc | |||
| 1d63306295 | |||
| 6c93626f6f | |||
| 72c5bf07c8 | |||
| ed59f90f15 | |||
| a09ca7f27e | |||
| 8c02572e16 | |||
| 27dde51de8 | |||
| 10d4a775f1 | |||
| 72d9a81d99 | |||
| 4fa85c7963 | |||
| 806e8e66fb | |||
| 0b90051db8 | |||
| b305c779b2 | |||
| 2b3cd2d39c | |||
| bc3d1c9ee6 | |||
| e50d614636 | |||
| a8df0f1ffb | |||
| ace53e2d2f | |||
| ffc2992fc2 | |||
| c70a285c2c | |||
| 8b811feece | |||
| 37e8dc7a59 | |||
| 024a9f5de3 | |||
| 005195c23e | |||
| 6742f160df | |||
| 540d303250 | |||
| f1b3036ca1 | |||
| 46ec1743a2 | |||
| 70272b1108 | |||
| 2b6dcbfa1d | |||
| af9572d759 | |||
| ddea157979 | |||
| ad3f9a26c0 | |||
| e8d0980f9f | |||
| 52a7f1cb97 | |||
| 33f85fadf6 |
@@ -180,6 +180,55 @@ jobs:
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_torch_compile_tests:
|
||||
name: PyTorch Compile CUDA tests
|
||||
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-compile-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Run torch compile tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_big_gpu_torch_tests:
|
||||
name: Torch tests on big GPU
|
||||
strategy:
|
||||
|
||||
@@ -335,7 +335,7 @@ jobs:
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
- name: Run torch compile tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
|
||||
@@ -28,6 +28,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
<Tip>
|
||||
@@ -91,6 +92,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
|
||||
|
||||
## HiDreamImageLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
|
||||
|
||||
## LoraBaseMixin
|
||||
|
||||
[[autodoc]] loaders.lora_base.LoraBaseMixin
|
||||
@@ -347,7 +347,7 @@ image = pipe(
|
||||
height=1024,
|
||||
prompt="wearing sunglasses",
|
||||
negative_prompt="",
|
||||
true_cfg=4.0,
|
||||
true_cfg_scale=4.0,
|
||||
generator=torch.Generator().manual_seed(4444),
|
||||
ip_adapter_image=image,
|
||||
).images[0]
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
|
||||
## Generating Videos with Wan 2.1
|
||||
|
||||
We will first need to install some addtional dependencies.
|
||||
We will first need to install some additional dependencies.
|
||||
|
||||
```shell
|
||||
pip install -u ftfy imageio-ffmpeg imageio
|
||||
|
||||
@@ -216,7 +216,7 @@ Setting the `<ID_TOKEN>` is not necessary. From some limited experimentation, we
|
||||
> - The original repository uses a `lora_alpha` of `1`. We found this not suitable in many runs, possibly due to difference in modeling backends and training settings. Our recommendation is to set to the `lora_alpha` to either `rank` or `rank // 2`.
|
||||
> - If you're training on data whose captions generate bad results with the original model, a `rank` of 64 and above is good and also the recommendation by the team behind CogVideoX. If the generations are already moderately good on your training captions, a `rank` of 16/32 should work. We found that setting the rank too low, say `4`, is not ideal and doesn't produce promising results.
|
||||
> - The authors of CogVideoX recommend 4000 training steps and 100 training videos overall to achieve the best result. While that might yield the best results, we found from our limited experimentation that 2000 steps and 25 videos could also be sufficient.
|
||||
> - When using the Prodigy opitimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From my very limited testing, I found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `prodigy_safeguard_warmup` and `--prodigy_decouple`.
|
||||
> - When using the Prodigy optimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From my very limited testing, I found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `prodigy_safeguard_warmup` and `--prodigy_decouple`.
|
||||
> - The recommended learning rate by the CogVideoX authors and from our experimentation with Adam/AdamW is between `1e-3` and `1e-4` for a dataset of 25+ videos.
|
||||
>
|
||||
> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.
|
||||
|
||||
@@ -589,7 +589,7 @@ For stage 2 of DeepFloyd IF with DreamBooth, pay attention to these parameters:
|
||||
|
||||
* `--learning_rate=5e-6`, use a lower learning rate with a smaller effective batch size
|
||||
* `--resolution=256`, the expected resolution for the upscaler
|
||||
* `--train_batch_size=2` and `--gradient_accumulation_steps=6`, to effectively train on images wiht faces requires larger batch sizes
|
||||
* `--train_batch_size=2` and `--gradient_accumulation_steps=6`, to effectively train on images with faces requires larger batch sizes
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
|
||||
|
||||
@@ -89,7 +89,7 @@ Many of the basic and important parameters are described in the [Text-to-image](
|
||||
|
||||
As with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. Instead, this guide takes a look at the T2I-Adapter relevant parts of the script.
|
||||
|
||||
The training script begins by preparing the dataset. This incudes [tokenizing](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L674) the prompt and [applying transforms](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L714) to the images and conditioning images.
|
||||
The training script begins by preparing the dataset. This includes [tokenizing](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L674) the prompt and [applying transforms](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L714) to the images and conditioning images.
|
||||
|
||||
```py
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
|
||||
@@ -2181,7 +2181,7 @@ def main(args):
|
||||
# Predict the noise residual
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
timestep=timesteps / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
|
||||
@@ -86,6 +86,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
|
||||
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
|
||||
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
|
||||
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://arxiv.org/abs/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
|
||||
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
```py
|
||||
@@ -5381,7 +5382,7 @@ pipe = DiffusionPipeline.from_pretrained(
|
||||
# Here we need use pipeline internal unet model
|
||||
pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
|
||||
|
||||
# Load aditional layers to the model
|
||||
# Load additional layers to the model
|
||||
pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype)
|
||||
|
||||
# Enable vae tiling
|
||||
@@ -5432,4 +5433,50 @@ cropped_image = gen_image.crop((0, 0, width_init, height_init))
|
||||
cropped_image.save("data/result.png")
|
||||
````
|
||||
### Result
|
||||
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
|
||||
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
|
||||
|
||||
|
||||
# Stable Diffusion 3 InstructPix2Pix Pipeline
|
||||
This the implementation of the Stable Diffusion 3 InstructPix2Pix Pipeline, based on the HuggingFace Diffusers.
|
||||
|
||||
## Example Usage
|
||||
This pipeline aims to edit image based on user's instruction by using SD3
|
||||
````py
|
||||
import torch
|
||||
from diffusers import SD3Transformer2DModel
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
resolution = 512
|
||||
image = load_image("https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png").resize(
|
||||
(resolution, resolution)
|
||||
)
|
||||
edit_instruction = "Turn sky into a sunny one"
|
||||
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", custom_pipeline="pipeline_stable_diffusion_3_instruct_pix2pix", torch_dtype=torch.float16).to('cuda')
|
||||
|
||||
pipe.transformer = SD3Transformer2DModel.from_pretrained("CaptainZZZ/sd3-instructpix2pix",torch_dtype=torch.float16).to('cuda')
|
||||
|
||||
edited_image = pipe(
|
||||
prompt=edit_instruction,
|
||||
image=image,
|
||||
height=resolution,
|
||||
width=resolution,
|
||||
guidance_scale=7.5,
|
||||
image_guidance_scale=1.5,
|
||||
num_inference_steps=30,
|
||||
).images[0]
|
||||
|
||||
edited_image.save("edited_image.png")
|
||||
````
|
||||
|Original|Edited|
|
||||
|---|---|
|
||||
||
|
||||
|
||||
### Note
|
||||
This model is trained on 512x512, so input size is better on 512x512.
|
||||
For better editing performance, please refer to this powerful model https://huggingface.co/BleachNick/SD3_UltraEdit_freeform and Paper "UltraEdit: Instruction-based Fine-Grained Image
|
||||
Editing at Scale", many thanks to their contribution!
|
||||
@@ -312,9 +312,9 @@ if __name__ == "__main__":
|
||||
# These are the coordinates of the output image
|
||||
out_coordinates = np.arange(1, out_length + 1)
|
||||
|
||||
# since both scale-factor and output size can be provided simulatneously, perserving the center of the image requires shifting
|
||||
# the output coordinates. the deviation is because out_length doesn't necesary equal in_length*scale.
|
||||
# to keep the center we need to subtract half of this deivation so that we get equal margins for boths sides and center is preserved.
|
||||
# since both scale-factor and output size can be provided simultaneously, preserving the center of the image requires shifting
|
||||
# the output coordinates. the deviation is because out_length doesn't necessary equal in_length*scale.
|
||||
# to keep the center we need to subtract half of this deviation so that we get equal margins for both sides and center is preserved.
|
||||
shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
|
||||
|
||||
# These are the matching positions of the output-coordinates on the input image coordinates.
|
||||
|
||||
@@ -351,7 +351,7 @@ def my_forward(
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
|
||||
Returns:
|
||||
@@ -864,9 +864,9 @@ def get_flow_and_interframe_paras(flow_model, imgs):
|
||||
class AttentionControl:
|
||||
"""
|
||||
Control FRESCO-based attention
|
||||
* enable/diable spatial-guided attention
|
||||
* enable/diable temporal-guided attention
|
||||
* enable/diable cross-frame attention
|
||||
* enable/disable spatial-guided attention
|
||||
* enable/disable temporal-guided attention
|
||||
* enable/disable cross-frame attention
|
||||
* collect intermediate attention feature (for spatial-guided attention)
|
||||
"""
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class RASGAttnProcessor:
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
# Same as the default AttnProcessor up untill the part where similarity matrix gets saved
|
||||
# Same as the default AttnProcessor up until the part where similarity matrix gets saved
|
||||
downscale_factor = self.mask_resoltuion // hidden_states.shape[1]
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -889,7 +889,7 @@ def main(args):
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
|
||||
@@ -721,7 +721,7 @@ def main(args):
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
|
||||
@@ -884,7 +884,7 @@ def main(args):
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
|
||||
@@ -854,7 +854,7 @@ def main(args):
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
|
||||
@@ -894,7 +894,7 @@ def main(args):
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
|
||||
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
|
||||
@@ -6,7 +6,19 @@ Training script provided by LibAI, which is an institution dedicated to the prog
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
|
||||
|
||||
Here is a gpu memory consumption for reference, tested on a single A100 with 80G.
|
||||
|
||||
| period | GPU |
|
||||
| - | - |
|
||||
| load as float32 | ~70G |
|
||||
| mv transformer and vae to bf16 | ~48G |
|
||||
| pre compute txt embeddings | ~62G |
|
||||
| **offload te to cpu** | ~30G |
|
||||
| training | ~58G |
|
||||
| validation | ~71G |
|
||||
|
||||
|
||||
> **Gated access**
|
||||
>
|
||||
@@ -98,8 +110,9 @@ accelerate launch train_controlnet_flux.py \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="cosine" \
|
||||
--num_double_layers=4 \
|
||||
--num_single_layers=0 \
|
||||
--seed=42 \
|
||||
|
||||
@@ -148,7 +148,7 @@ def log_validation(
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
control_image=validation_image,
|
||||
num_inference_steps=28,
|
||||
controlnet_conditioning_scale=0.7,
|
||||
controlnet_conditioning_scale=1,
|
||||
guidance_scale=3.5,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
@@ -639,6 +639,15 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
help="Enable model cpu offload and save memory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator):
|
||||
|
||||
|
||||
def prepare_train_dataset(dataset, accelerator):
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
@@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator):
|
||||
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
@@ -1085,8 +1098,6 @@ def main(args):
|
||||
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
|
||||
|
||||
train_dataset = get_train_dataset(args, accelerator)
|
||||
text_encoders = [text_encoder_one, text_encoder_two]
|
||||
tokenizers = [tokenizer_one, tokenizer_two]
|
||||
compute_embeddings_fn = functools.partial(
|
||||
compute_embeddings,
|
||||
flux_controlnet_pipeline=flux_controlnet_pipeline,
|
||||
@@ -1103,7 +1114,8 @@ def main(args):
|
||||
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
|
||||
)
|
||||
|
||||
del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
|
||||
text_encoder_one.to("cpu")
|
||||
text_encoder_two.to("cpu")
|
||||
free_memory()
|
||||
|
||||
# Then get the training dataset ready to be passed to the dataloader.
|
||||
|
||||
@@ -134,7 +134,25 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
validation_image = validation_image.resize((args.resolution, args.resolution))
|
||||
|
||||
try:
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
|
||||
except (AttributeError, KeyError):
|
||||
supported_interpolation_modes = [
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
]
|
||||
raise ValueError(
|
||||
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
|
||||
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
|
||||
)
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
]
|
||||
)
|
||||
validation_image = transform(validation_image)
|
||||
|
||||
images = []
|
||||
|
||||
@@ -587,6 +605,15 @@ def parse_args(input_args=None):
|
||||
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -732,9 +759,20 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
|
||||
|
||||
|
||||
def prepare_train_dataset(dataset, accelerator):
|
||||
try:
|
||||
interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
|
||||
except (AttributeError, KeyError):
|
||||
supported_interpolation_modes = [
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
]
|
||||
raise ValueError(
|
||||
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
|
||||
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
|
||||
)
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation_mode),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
@@ -743,7 +781,7 @@ def prepare_train_dataset(dataset, accelerator):
|
||||
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation_mode),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
# DreamBooth training example for HiDream Image
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
|
||||
The `train_dreambooth_lora_hidream.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/).
|
||||
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_hidream.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
|
||||
|
||||
|
||||
### 3d icon example
|
||||
|
||||
For this example we will use some 3d icon images: https://huggingface.co/datasets/linoyts/3d_icon.
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
Now, we can launch training using:
|
||||
> [!NOTE]
|
||||
> The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
|
||||
> 8-bit Adam optimizer, latent caching, offloading, no validation.
|
||||
> all text embeddings are pre-computed to save memory.
|
||||
```bash
|
||||
export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
|
||||
export INSTANCE_DIR="linoyts/3d_icon"
|
||||
export OUTPUT_DIR="trained-hidream-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_hidream.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="3d icon" \
|
||||
--caption_column="prompt"\
|
||||
--validation_prompt="a 3dicon, a llama eating ramen" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--learning_rate=2e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant_with_warmup" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=1000 \
|
||||
--cache_latents\
|
||||
--gradient_checkpointing \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
For using `push_to_hub`, make you're logged into your Hugging Face account:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
## Notes
|
||||
|
||||
Additionally, we welcome you to explore the following CLI arguments:
|
||||
|
||||
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
|
||||
* `--rank`: The rank of the LoRA layers. The higher the rank, the more parameters are trained. The default is 16.
|
||||
|
||||
We provide several options for optimizing memory optimization:
|
||||
|
||||
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
|
||||
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
|
||||
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
|
||||
|
||||
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
|
||||
@@ -0,0 +1,8 @@
|
||||
accelerate>=1.4.0
|
||||
torchvision
|
||||
transformers>=4.50.0
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft>=0.14.0
|
||||
sentencepiece
|
||||
@@ -0,0 +1,220 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe"
|
||||
text_encoder_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
tokenizer_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_hidream.py"
|
||||
transformer_layer_type = "double_stream_blocks.0.block.attn1.to_k"
|
||||
|
||||
def test_dreambooth_lora_hidream(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# `self.transformer_layer_type` should be in the state dict.
|
||||
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
|
||||
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
resume_run_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -618,6 +618,15 @@ def parse_args(input_args=None):
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -737,7 +746,10 @@ class DreamBoothDataset(Dataset):
|
||||
self.instance_images.extend(itertools.repeat(img, repeats))
|
||||
|
||||
self.pixel_values = []
|
||||
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
||||
train_resize = transforms.Resize(size, interpolation=interpolation)
|
||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
@@ -1622,7 +1634,7 @@ def main(args):
|
||||
# Predict the noise residual
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
timestep=timesteps / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
|
||||
@@ -524,6 +524,15 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -601,9 +610,13 @@ class DreamBoothDataset(Dataset):
|
||||
else:
|
||||
self.class_data_root = None
|
||||
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
||||
|
||||
self.image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(size, interpolation=interpolation),
|
||||
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
|
||||
@@ -1749,7 +1749,7 @@ def main(args):
|
||||
# Predict the noise residual
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
timestep=timesteps / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+1
-1
@@ -1088,7 +1088,7 @@ def main(args):
|
||||
text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype)
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
timestep=timesteps / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
|
||||
@@ -499,6 +499,15 @@ def parse_args():
|
||||
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -787,10 +796,17 @@ def main():
|
||||
)
|
||||
return inputs.input_ids
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# Get the specified interpolation method from the args
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
|
||||
# Raise an error if the interpolation method is invalid
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
|
||||
|
||||
# Data preprocessing transformations
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
|
||||
@@ -418,6 +418,15 @@ def parse_args():
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -649,10 +658,17 @@ def main():
|
||||
)
|
||||
return inputs.input_ids
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# Get the specified interpolation method from the args
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
|
||||
# Raise an error if the interpolation method is invalid
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
|
||||
|
||||
# Data preprocessing transformations
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
|
||||
@@ -34,6 +34,7 @@ from .utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["ConfigMixin"],
|
||||
"guiders": [],
|
||||
"hooks": [],
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
@@ -130,12 +131,26 @@ except OptionalDependencyNotAvailable:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
_import_structure["guiders"].extend(
|
||||
[
|
||||
"AdaptiveProjectedGuidance",
|
||||
"AutoGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"SkipLayerGuidance",
|
||||
"SmoothedEnergyGuidance",
|
||||
"TangentialClassifierFreeGuidance",
|
||||
]
|
||||
)
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"HookRegistry",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"LayerSkipConfig",
|
||||
"SmoothedEnergyGuidanceConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -239,6 +254,7 @@ else:
|
||||
"KarrasVePipeline",
|
||||
"LDMPipeline",
|
||||
"LDMSuperResolutionPipeline",
|
||||
"ModularLoader",
|
||||
"PNDMPipeline",
|
||||
"RePaintPipeline",
|
||||
"ScoreSdeVePipeline",
|
||||
@@ -493,10 +509,12 @@ else:
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLModularLoader",
|
||||
"StableDiffusionXLPAGImg2ImgPipeline",
|
||||
"StableDiffusionXLPAGInpaintPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
"StableDiffusionXLAutoPipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
"StableVideoDiffusionPipeline",
|
||||
@@ -728,11 +746,23 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .guiders import (
|
||||
AdaptiveProjectedGuidance,
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
TangentialClassifierFreeGuidance,
|
||||
)
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
SmoothedEnergyGuidanceConfig,
|
||||
apply_faster_cache,
|
||||
apply_layer_skip,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
@@ -834,6 +864,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KarrasVePipeline,
|
||||
LDMPipeline,
|
||||
LDMSuperResolutionPipeline,
|
||||
ModularLoader,
|
||||
PNDMPipeline,
|
||||
RePaintPipeline,
|
||||
ScoreSdeVePipeline,
|
||||
@@ -1054,6 +1085,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
StableDiffusionXLAutoPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
@@ -1066,6 +1098,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLModularLoader,
|
||||
StableDiffusionXLPAGImg2ImgPipeline,
|
||||
StableDiffusionXLPAGInpaintPipeline,
|
||||
StableDiffusionXLPAGPipeline,
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
|
||||
|
||||
GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance]
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
"""
|
||||
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
|
||||
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
||||
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
adaptive_projected_guidance_momentum: Optional[float] = None,
|
||||
adaptive_projected_guidance_rescale: float = 15.0,
|
||||
eta: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.eta = eta
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_apg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = normalized_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_apg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_apg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
diff = momentum_buffer.running_average
|
||||
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + guidance_scale * normalized_update
|
||||
|
||||
return pred
|
||||
@@ -0,0 +1,174 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..hooks import HookRegistry, LayerSkipConfig
|
||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class AutoGuidance(BaseGuidance):
|
||||
"""
|
||||
AutoGuidance: https://huggingface.co/papers/2406.02507
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
auto_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided.
|
||||
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
dropout (`float`, *optional*):
|
||||
The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
|
||||
`auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
|
||||
dropout: Optional[float] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.auto_guidance_layers = auto_guidance_layers
|
||||
self.auto_guidance_config = auto_guidance_config
|
||||
self.dropout = dropout
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if auto_guidance_layers is None and auto_guidance_config is None:
|
||||
raise ValueError(
|
||||
"Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance."
|
||||
)
|
||||
if auto_guidance_layers is not None and auto_guidance_config is not None:
|
||||
raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
|
||||
if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None):
|
||||
raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
|
||||
|
||||
if auto_guidance_layers is not None:
|
||||
if isinstance(auto_guidance_layers, int):
|
||||
auto_guidance_layers = [auto_guidance_layers]
|
||||
if not isinstance(auto_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
|
||||
)
|
||||
auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers]
|
||||
|
||||
if isinstance(auto_guidance_config, LayerSkipConfig):
|
||||
auto_guidance_config = [auto_guidance_config]
|
||||
|
||||
if not isinstance(auto_guidance_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
|
||||
)
|
||||
|
||||
self.auto_guidance_config = auto_guidance_config
|
||||
self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
self._count_prepared += 1
|
||||
if self._is_ag_enabled() and self.is_unconditional:
|
||||
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
if self._is_ag_enabled() and self.is_unconditional:
|
||||
for name in self._auto_guidance_hook_names:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
registry.remove_hook(name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_ag_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_ag_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_ag_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
@@ -0,0 +1,129 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(BaseGuidance):
|
||||
"""
|
||||
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity.
|
||||
The original paper proposes scaling and shifting the conditional distribution based on the difference between
|
||||
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
|
||||
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
|
||||
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
|
||||
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
@@ -0,0 +1,145 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
"""
|
||||
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
|
||||
|
||||
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
|
||||
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
|
||||
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
|
||||
quality of generated images.
|
||||
|
||||
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
zero_init_steps (`int`, defaults to `1`):
|
||||
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
zero_init_steps: int = 1,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.zero_init_steps = zero_init_steps
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if self._step < self.zero_init_steps:
|
||||
pred = torch.zeros_like(pred_cond)
|
||||
elif not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred_cond_flat = pred_cond.flatten(1)
|
||||
pred_uncond_flat = pred_uncond.flatten(1)
|
||||
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
|
||||
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
|
||||
pred_uncond = pred_uncond * alpha
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
|
||||
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
||||
cond_dtype = cond.dtype
|
||||
cond = cond.float()
|
||||
uncond = uncond.float()
|
||||
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
|
||||
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
scale = dot_product / squared_norm
|
||||
return scale.to(dtype=cond_dtype)
|
||||
@@ -0,0 +1,215 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BaseGuidance:
|
||||
r"""Base class providing the skeleton for implementing guidance techniques."""
|
||||
|
||||
_input_predictions = None
|
||||
_identifier_key = "__guidance_identifier__"
|
||||
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._step: int = None
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
||||
self._enabled = True
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `start` to be between 0.0 and 1.0, but got {start}."
|
||||
)
|
||||
if not (start <= stop <= 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `stop` to be between {start} and 1.0, but got {stop}."
|
||||
)
|
||||
|
||||
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
||||
raise ValueError(
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def disable(self):
|
||||
self._enabled = False
|
||||
|
||||
def enable(self):
|
||||
self._enabled = True
|
||||
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||
self._step = step
|
||||
self._num_inference_steps = num_inference_steps
|
||||
self._timestep = timestep
|
||||
self._count_prepared = 0
|
||||
|
||||
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
|
||||
"""
|
||||
Set the input fields for the guidance technique. The input fields are used to specify the names of the
|
||||
returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is
|
||||
obtained from the values of the provided keyword arguments to this method.
|
||||
|
||||
Args:
|
||||
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once
|
||||
it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
|
||||
which is used to look up the required data provided for preparation.
|
||||
|
||||
If a string is provided, it will be used as the conditional data (or unconditional if used with
|
||||
a guidance method that requires it). If a tuple of length 2 is provided, the first element must
|
||||
be the conditional data identifier and the second element must be the unconditional data identifier
|
||||
or None.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
|
||||
|
||||
BaseGuidance.set_input_fields(
|
||||
latents="latents",
|
||||
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
||||
)
|
||||
```
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
is_string = isinstance(value, str)
|
||||
is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
|
||||
if not (is_string or is_tuple_of_str_with_len_2):
|
||||
raise ValueError(
|
||||
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
||||
)
|
||||
self._input_fields = kwargs
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
||||
subclasses to implement specific model preparation logic.
|
||||
"""
|
||||
self._count_prepared += 1
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
|
||||
subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
|
||||
modifications made during `prepare_models`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
if len(data) != self.num_conditions:
|
||||
raise ValueError(
|
||||
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
|
||||
)
|
||||
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
|
||||
return self.forward(**forward_inputs)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def is_unconditional(self) -> bool:
|
||||
return not self.is_conditional
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
|
||||
|
||||
@classmethod
|
||||
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
|
||||
"""
|
||||
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of
|
||||
the `BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
||||
|
||||
Args:
|
||||
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once
|
||||
it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
|
||||
which is used to look up the required data provided for preparation.
|
||||
If a string is provided, it will be used as the conditional data (or unconditional if used with
|
||||
a guidance method that requires it). If a tuple of length 2 is provided, the first element must
|
||||
be the conditional data identifier and the second element must be the unconditional data identifier
|
||||
or None.
|
||||
data (`BlockState`):
|
||||
The input data to be prepared.
|
||||
tuple_index (`int`):
|
||||
The index to use when accessing input fields that are tuples.
|
||||
|
||||
Returns:
|
||||
`BlockState`: The prepared batch of data.
|
||||
"""
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
if input_fields is None:
|
||||
raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.")
|
||||
data_batch = {}
|
||||
for key, value in input_fields.items():
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
data_batch[key] = getattr(data, value)
|
||||
elif isinstance(value, tuple):
|
||||
data_batch[key] = getattr(data, value[tuple_index])
|
||||
else:
|
||||
# We've already checked that value is a string or a tuple of strings with length 2
|
||||
pass
|
||||
except AttributeError:
|
||||
raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.")
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
@@ -0,0 +1,248 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..hooks import HookRegistry, LayerSkipConfig
|
||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class SkipLayerGuidance(BaseGuidance):
|
||||
"""
|
||||
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
|
||||
|
||||
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
|
||||
|
||||
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
|
||||
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
|
||||
batch of data, apart from the conditional and unconditional batches already used in CFG
|
||||
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
|
||||
based on the difference between conditional without skipping and conditional with skipping predictions.
|
||||
|
||||
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
|
||||
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
|
||||
version of the model for the conditional prediction).
|
||||
|
||||
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
|
||||
generation quality in video diffusion models.
|
||||
|
||||
Additional reading:
|
||||
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
||||
|
||||
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
|
||||
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
skip_layer_guidance_scale (`float`, defaults to `2.8`):
|
||||
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
|
||||
values, but it may also lead to overexposure and saturation.
|
||||
skip_layer_guidance_start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
||||
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
||||
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
||||
3.5 Medium.
|
||||
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
self.skip_layer_guidance_start = skip_layer_guidance_start
|
||||
self.skip_layer_guidance_stop = skip_layer_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if not (0.0 <= skip_layer_guidance_start < 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
|
||||
)
|
||||
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
|
||||
)
|
||||
|
||||
if skip_layer_guidance_layers is None and skip_layer_config is None:
|
||||
raise ValueError(
|
||||
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
|
||||
)
|
||||
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
|
||||
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
|
||||
|
||||
if skip_layer_guidance_layers is not None:
|
||||
if isinstance(skip_layer_guidance_layers, int):
|
||||
skip_layer_guidance_layers = [skip_layer_guidance_layers]
|
||||
if not isinstance(skip_layer_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
|
||||
)
|
||||
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
|
||||
|
||||
if isinstance(skip_layer_config, LayerSkipConfig):
|
||||
skip_layer_config = [skip_layer_config]
|
||||
|
||||
if not isinstance(skip_layer_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
|
||||
)
|
||||
|
||||
self.skip_layer_config = skip_layer_config
|
||||
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
self._count_prepared += 1
|
||||
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_cond_skip: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
||||
pred = pred + self.skip_layer_guidance_scale * shift
|
||||
elif not self._is_slg_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_skip = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1 or self._count_prepared == 3
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_slg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_slg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||
|
||||
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
||||
|
||||
return is_within_range and not is_zero
|
||||
@@ -0,0 +1,241 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..hooks import HookRegistry
|
||||
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class SmoothedEnergyGuidance(BaseGuidance):
|
||||
"""
|
||||
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
|
||||
|
||||
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
|
||||
in the future without warning or guarantee of reproducibility. This implementation assumes:
|
||||
- Generated images are square (height == width)
|
||||
- The model does not combine different modalities together (e.g., text and image latent streams are
|
||||
not combined together such as Flux)
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
seg_guidance_scale (`float`, defaults to `3.0`):
|
||||
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
|
||||
values, but it may also lead to overexposure and saturation.
|
||||
seg_blur_sigma (`float`, defaults to `9999999.0`):
|
||||
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
|
||||
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
|
||||
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
|
||||
The threshold above which the blur is considered infinite.
|
||||
seg_guidance_start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
|
||||
seg_guidance_stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
|
||||
seg_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
||||
3.5 Medium.
|
||||
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
|
||||
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of
|
||||
`SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
seg_guidance_scale: float = 2.8,
|
||||
seg_blur_sigma: float = 9999999.0,
|
||||
seg_blur_threshold_inf: float = 9999.0,
|
||||
seg_guidance_start: float = 0.0,
|
||||
seg_guidance_stop: float = 1.0,
|
||||
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.seg_guidance_scale = seg_guidance_scale
|
||||
self.seg_blur_sigma = seg_blur_sigma
|
||||
self.seg_blur_threshold_inf = seg_blur_threshold_inf
|
||||
self.seg_guidance_start = seg_guidance_start
|
||||
self.seg_guidance_stop = seg_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if not (0.0 <= seg_guidance_start < 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}."
|
||||
)
|
||||
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}."
|
||||
)
|
||||
|
||||
if seg_guidance_layers is None and seg_guidance_config is None:
|
||||
raise ValueError(
|
||||
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
|
||||
)
|
||||
if seg_guidance_layers is not None and seg_guidance_config is not None:
|
||||
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
|
||||
|
||||
if seg_guidance_layers is not None:
|
||||
if isinstance(seg_guidance_layers, int):
|
||||
seg_guidance_layers = [seg_guidance_layers]
|
||||
if not isinstance(seg_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
|
||||
)
|
||||
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
|
||||
|
||||
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
|
||||
seg_guidance_config = [seg_guidance_config]
|
||||
|
||||
if not isinstance(seg_guidance_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
|
||||
)
|
||||
|
||||
self.seg_guidance_config = seg_guidance_config
|
||||
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
|
||||
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._seg_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_cond_seg: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_seg_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_cond_seg
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_seg
|
||||
pred = pred + self.seg_guidance_scale * shift
|
||||
elif not self._is_seg_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_seg = pred_cond - pred_cond_seg
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1 or self._count_prepared == 3
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_seg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_seg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||
|
||||
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
|
||||
|
||||
return is_within_range and not is_zero
|
||||
@@ -0,0 +1,134 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
"""
|
||||
Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_tcfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._num_outputs_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_tcfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_tcfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
|
||||
def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
|
||||
cond_dtype = pred_cond.dtype
|
||||
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
|
||||
preds = preds.flatten(2)
|
||||
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
|
||||
Vh_modified = Vh.clone()
|
||||
Vh_modified[:, 1] = 0
|
||||
|
||||
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
|
||||
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
|
||||
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
|
||||
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
|
||||
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred + guidance_scale * shift
|
||||
|
||||
return pred
|
||||
@@ -5,5 +5,7 @@ if is_torch_available():
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention import FeedForward, LuminaFeedForward
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
|
||||
{
|
||||
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
|
||||
for submodule_name, submodule in module.named_modules():
|
||||
if submodule_name == fqn:
|
||||
return submodule
|
||||
return None
|
||||
@@ -0,0 +1,271 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Type
|
||||
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
from ..models.attention_processor import AttnProcessor2_0
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock
|
||||
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
from ..models.transformers.transformer_hunyuan_video import (
|
||||
HunyuanVideoSingleTransformerBlock,
|
||||
HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
HunyuanVideoTokenReplaceTransformerBlock,
|
||||
HunyuanVideoTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionProcessorMetadata:
|
||||
skip_processor_output_fn: Callable[[Any], Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerBlockMetadata:
|
||||
skip_block_output_fn: Callable[[Any], Any]
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
|
||||
|
||||
class AttentionProcessorRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
|
||||
class TransformerBlockRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> TransformerBlockMetadata:
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
|
||||
def _register_attention_processors_metadata():
|
||||
# AttnProcessor2_0
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=AttnProcessor2_0,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
|
||||
# CogView4AttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=CogView4AttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
# BasicTransformerBlock
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=BasicTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
# CogVideoX
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=CogVideoXBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# CogView4
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=CogView4TransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Flux
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=FluxTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=FluxSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanVideo
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTokenReplaceTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# LTXVideo
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=LTXVideoTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
# Mochi
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=MochiTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Wan
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=WanTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
|
||||
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
|
||||
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
|
||||
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
|
||||
_register_attention_processors_metadata()
|
||||
_register_transformer_blocks_metadata()
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -55,9 +55,9 @@ class ModuleGroup:
|
||||
parameters: Optional[List[torch.nn.Parameter]] = None,
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage=False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -115,8 +115,13 @@ class ModuleGroup:
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
|
||||
current_stream = torch.cuda.current_stream() if self.record_stream else None
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
else torch.cuda
|
||||
)
|
||||
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
|
||||
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
|
||||
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
@@ -162,9 +167,15 @@ class ModuleGroup:
|
||||
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
else torch.cuda
|
||||
)
|
||||
if self.stream is not None:
|
||||
if not self.record_stream:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
torch_accelerator_module.current_stream().synchronize()
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
@@ -429,8 +440,10 @@ def apply_group_offloading(
|
||||
if use_stream:
|
||||
if torch.cuda.is_available():
|
||||
stream = torch.cuda.Stream()
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
stream = torch.Stream()
|
||||
else:
|
||||
raise ValueError("Using streams for data transfer requires a CUDA device.")
|
||||
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream`, *optional*):
|
||||
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
@@ -498,6 +511,11 @@ def _apply_group_offloading_block_level(
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
"""
|
||||
if stream is not None and num_blocks_per_group != 1:
|
||||
logger.warning(
|
||||
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
|
||||
)
|
||||
num_blocks_per_group = 1
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -521,7 +539,7 @@ def _apply_group_offloading_block_level(
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
onload_self=stream is None,
|
||||
onload_self=True,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
for j in range(i, i + len(current_modules)):
|
||||
@@ -529,12 +547,8 @@ def _apply_group_offloading_block_level(
|
||||
|
||||
# Apply group offloading hooks to the module groups
|
||||
for i, group in enumerate(matched_module_groups):
|
||||
next_group = (
|
||||
matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
|
||||
)
|
||||
|
||||
for group_module in group.modules:
|
||||
_apply_group_offloading_hook(group_module, group, next_group)
|
||||
_apply_group_offloading_hook(group_module, group, None)
|
||||
|
||||
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
|
||||
# when the forward pass of this module is called. This is because the top-level module is not
|
||||
@@ -560,8 +574,10 @@ def _apply_group_offloading_block_level(
|
||||
record_stream=False,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
_apply_group_offloading_hook(module, unmatched_group, next_group)
|
||||
if stream is None:
|
||||
_apply_group_offloading_hook(module, unmatched_group, None)
|
||||
else:
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
|
||||
def _apply_group_offloading_leaf_level(
|
||||
@@ -569,7 +585,7 @@ def _apply_group_offloading_leaf_level(
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
@@ -589,7 +605,7 @@ def _apply_group_offloading_leaf_level(
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream`, *optional*):
|
||||
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import (
|
||||
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
_ATTENTION_CLASSES,
|
||||
_FEEDFORWARD_CLASSES,
|
||||
_get_submodule_from_fqn,
|
||||
)
|
||||
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_LAYER_SKIP_HOOK = "layer_skip_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerSkipConfig:
|
||||
r"""
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
For automatic detection, set this to `"auto"`.
|
||||
"auto" only works on DiT models. For UNet models, you must provide the correct fqn.
|
||||
skip_attention (`bool`, defaults to `True`):
|
||||
Whether to skip attention blocks.
|
||||
skip_ff (`bool`, defaults to `True`):
|
||||
Whether to skip feed-forward blocks.
|
||||
skip_attention_scores (`bool`, defaults to `False`):
|
||||
Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
|
||||
projections as the output of scaled dot product attention.
|
||||
dropout (`float`, defaults to `1.0`):
|
||||
The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
|
||||
meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
|
||||
skipped layers are fully retained, which is equivalent to not skipping any layers.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
fqn: str = "auto"
|
||||
skip_attention: bool = True
|
||||
skip_attention_scores: bool = False
|
||||
skip_ff: bool = True
|
||||
dropout: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if not (0 <= self.dropout <= 1):
|
||||
raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
|
||||
if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
|
||||
raise ValueError(
|
||||
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
|
||||
)
|
||||
|
||||
|
||||
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func is torch.nn.functional.scaled_dot_product_attention:
|
||||
value = kwargs.get("value", None)
|
||||
if value is None:
|
||||
value = args[2]
|
||||
return value
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
class AttentionProcessorSkipHook(ModelHook):
|
||||
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
|
||||
self.skip_processor_output_fn = skip_processor_output_fn
|
||||
self.skip_attention_scores = skip_attention_scores
|
||||
self.dropout = dropout
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.skip_attention_scores:
|
||||
if not math.isclose(self.dropout, 1.0):
|
||||
raise ValueError(
|
||||
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
|
||||
)
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
if math.isclose(self.dropout, 1.0):
|
||||
output = self.skip_processor_output_fn(module, *args, **kwargs)
|
||||
else:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
output = torch.nn.functional.dropout(output, p=self.dropout)
|
||||
return output
|
||||
|
||||
|
||||
class FeedForwardSkipHook(ModelHook):
|
||||
def __init__(self, dropout: float):
|
||||
super().__init__()
|
||||
self.dropout = dropout
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if math.isclose(self.dropout, 1.0):
|
||||
output = kwargs.get("hidden_states", None)
|
||||
if output is None:
|
||||
output = kwargs.get("x", None)
|
||||
if output is None and len(args) > 0:
|
||||
output = args[0]
|
||||
else:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
output = torch.nn.functional.dropout(output, p=self.dropout)
|
||||
return output
|
||||
|
||||
|
||||
class TransformerBlockSkipHook(ModelHook):
|
||||
def __init__(self, dropout: float):
|
||||
super().__init__()
|
||||
self.dropout = dropout
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if math.isclose(self.dropout, 1.0):
|
||||
output = self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
else:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
output = torch.nn.functional.dropout(output, p=self.dropout)
|
||||
return output
|
||||
|
||||
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
|
||||
r"""
|
||||
Apply layer skipping to internal layers of a transformer.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The transformer model to which the layer skip hook should be applied.
|
||||
config (`LayerSkipConfig`):
|
||||
The configuration for the layer skip hook.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
|
||||
>>> apply_layer_skip_hook(transformer, config)
|
||||
```
|
||||
"""
|
||||
_apply_layer_skip_hook(module, config)
|
||||
|
||||
|
||||
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
|
||||
name = name or _LAYER_SKIP_HOOK
|
||||
|
||||
if config.skip_attention and config.skip_attention_scores:
|
||||
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
|
||||
if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
|
||||
raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.")
|
||||
|
||||
if config.fqn == "auto":
|
||||
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
||||
if hasattr(module, identifier):
|
||||
config.fqn = identifier
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
||||
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
||||
)
|
||||
|
||||
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
|
||||
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
|
||||
raise ValueError(
|
||||
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
|
||||
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
|
||||
)
|
||||
if len(config.indices) == 0:
|
||||
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
|
||||
|
||||
blocks_found = False
|
||||
for i, block in enumerate(transformer_blocks):
|
||||
if i not in config.indices:
|
||||
continue
|
||||
|
||||
blocks_found = True
|
||||
|
||||
if config.skip_attention and config.skip_ff:
|
||||
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = TransformerBlockSkipHook(config.dropout)
|
||||
registry.register_hook(hook, name)
|
||||
|
||||
elif config.skip_attention or config.skip_attention_scores:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
|
||||
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
|
||||
registry.register_hook(hook, name)
|
||||
|
||||
if config.skip_ff:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
||||
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = FeedForwardSkipHook(config.dropout)
|
||||
registry.register_hook(hook, name)
|
||||
|
||||
if not blocks_found:
|
||||
raise ValueError(
|
||||
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
|
||||
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
|
||||
)
|
||||
@@ -0,0 +1,158 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import get_logger
|
||||
from ._common import _ATTENTION_CLASSES, _get_submodule_from_fqn
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmoothedEnergyGuidanceConfig:
|
||||
r"""
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
For automatic detection, set this to `"auto"`.
|
||||
"auto" only works on DiT models. For UNet models, you must provide the correct fqn.
|
||||
_query_proj_identifiers (`List[str]`, defaults to `None`):
|
||||
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`.
|
||||
If `None`, `to_q` is used by default.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
fqn: str = "auto"
|
||||
_query_proj_identifiers: List[str] = None
|
||||
|
||||
|
||||
class SmoothedEnergyGuidanceHook(ModelHook):
|
||||
def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
|
||||
super().__init__()
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_threshold_inf = blur_threshold_inf
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
|
||||
# Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
|
||||
kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
|
||||
smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
|
||||
return smoothed_output
|
||||
|
||||
|
||||
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
|
||||
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
|
||||
|
||||
if config.fqn == "auto":
|
||||
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
||||
if hasattr(module, identifier):
|
||||
config.fqn = identifier
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
||||
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
||||
)
|
||||
|
||||
if config._query_proj_identifiers is None:
|
||||
config._query_proj_identifiers = ["to_q"]
|
||||
|
||||
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
|
||||
blocks_found = False
|
||||
for i, block in enumerate(transformer_blocks):
|
||||
if i not in config.indices:
|
||||
continue
|
||||
|
||||
blocks_found = True
|
||||
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
|
||||
continue
|
||||
for identifier in config._query_proj_identifiers:
|
||||
query_proj = getattr(submodule, identifier, None)
|
||||
if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
|
||||
continue
|
||||
logger.debug(
|
||||
f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
|
||||
)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
|
||||
hook = SmoothedEnergyGuidanceHook(blur_sigma)
|
||||
registry.register_hook(hook, name)
|
||||
|
||||
if not blocks_found:
|
||||
raise ValueError(
|
||||
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
|
||||
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
|
||||
)
|
||||
|
||||
|
||||
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
|
||||
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
|
||||
"""
|
||||
This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian
|
||||
blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally,
|
||||
this implementation also assumes that the visual tokens come from a square image/video. In practice, despite
|
||||
these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results
|
||||
for Smoothed Energy Guidance.
|
||||
|
||||
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
|
||||
in the future without warning or guarantee of reproducibility.
|
||||
"""
|
||||
assert query.ndim == 3
|
||||
|
||||
is_inf = sigma > sigma_threshold_inf
|
||||
batch_size, seq_len, embed_dim = query.shape
|
||||
|
||||
seq_len_sqrt = int(math.sqrt(seq_len))
|
||||
num_square_tokens = seq_len_sqrt * seq_len_sqrt
|
||||
query_slice = query[:, :num_square_tokens, :]
|
||||
query_slice = query_slice.permute(0, 2, 1)
|
||||
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
|
||||
|
||||
if is_inf:
|
||||
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
|
||||
kernel_size_half = (kernel_size - 1) / 2
|
||||
|
||||
x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
|
||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
kernel1d = pdf / pdf.sum()
|
||||
kernel1d = kernel1d.to(query)
|
||||
kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
|
||||
kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||
|
||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||
query_slice = F.pad(query_slice, padding, mode="reflect")
|
||||
query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
|
||||
else:
|
||||
query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
|
||||
|
||||
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
|
||||
query_slice = query_slice.permute(0, 2, 1)
|
||||
query[:, :num_square_tokens, :] = query_slice.clone()
|
||||
|
||||
return query
|
||||
@@ -116,6 +116,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
vae_scale_factor: int = 8,
|
||||
vae_latent_channels: int = 4,
|
||||
resample: str = "lanczos",
|
||||
reducing_gap: int = None,
|
||||
do_normalize: bool = True,
|
||||
do_binarize: bool = False,
|
||||
do_convert_rgb: bool = False,
|
||||
@@ -498,7 +499,11 @@ class VaeImageProcessor(ConfigMixin):
|
||||
raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
if resize_mode == "default":
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
image = image.resize(
|
||||
(width, height),
|
||||
resample=PIL_INTERPOLATION[self.config.resample],
|
||||
reducing_gap=self.config.reducing_gap,
|
||||
)
|
||||
elif resize_mode == "fill":
|
||||
image = self._resize_and_fill(image, width, height)
|
||||
elif resize_mode == "crop":
|
||||
|
||||
@@ -77,12 +77,14 @@ if is_torch_available():
|
||||
"SanaLoraLoaderMixin",
|
||||
"Lumina2LoraLoaderMixin",
|
||||
"WanLoraLoaderMixin",
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
"IPAdapterMixin",
|
||||
"FluxIPAdapterMixin",
|
||||
"SD3IPAdapterMixin",
|
||||
"ModularIPAdapterMixin",
|
||||
]
|
||||
|
||||
_import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
@@ -100,6 +102,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .ip_adapter import (
|
||||
FluxIPAdapterMixin,
|
||||
IPAdapterMixin,
|
||||
ModularIPAdapterMixin,
|
||||
SD3IPAdapterMixin,
|
||||
)
|
||||
from .lora_pipeline import (
|
||||
@@ -108,6 +111,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
|
||||
@@ -356,6 +356,265 @@ class IPAdapterMixin:
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
class ModularIPAdapterMixin:
|
||||
"""Mixin for handling IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
|
||||
unet_name = getattr(self, "unet_name", "unet")
|
||||
unet = getattr(self, unet_name)
|
||||
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
extra_loras = unet._load_ip_adapter_loras(state_dicts)
|
||||
if extra_loras != {}:
|
||||
if not USE_PEFT_BACKEND:
|
||||
logger.warning("PEFT backend is required to load these weights.")
|
||||
else:
|
||||
# apply the IP Adapter Face ID LoRA weights
|
||||
peft_config = getattr(unet, "peft_config", {})
|
||||
for k, lora in extra_loras.items():
|
||||
if f"faceid_{k}" not in peft_config:
|
||||
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
|
||||
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style block only
|
||||
scale = {
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style+layout blocks
|
||||
scale = {
|
||||
"down": {"block_2": [0.0, 1.0]},
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style and layout from 2 reference images
|
||||
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
|
||||
pipeline.set_ip_adapter_scale(scales)
|
||||
```
|
||||
"""
|
||||
unet_name = getattr(self, "unet_name", "unet")
|
||||
unet = getattr(self, unet_name)
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale]
|
||||
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||
|
||||
for attn_name, attn_processor in unet.attn_processors.items():
|
||||
if isinstance(
|
||||
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
f"{len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
if isinstance(scale_config, dict):
|
||||
for k, s in scale_config.items():
|
||||
if attn_name.startswith(k):
|
||||
attn_processor.scale[i] = s
|
||||
else:
|
||||
attn_processor.scale[i] = scale_config
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
|
||||
# remove hidden encoder
|
||||
if self.unet is None:
|
||||
return
|
||||
|
||||
self.unet.encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = None
|
||||
|
||||
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
|
||||
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
|
||||
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
|
||||
self.unet.text_encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = "text_proj"
|
||||
|
||||
# restore original Unet attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.unet.attn_processors.items():
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(
|
||||
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
)
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class FluxIPAdapterMixin:
|
||||
"""Mixin for handling Flux IP Adapters."""
|
||||
|
||||
@@ -441,7 +441,7 @@ def _func_optionally_disable_offloading(_pipeline):
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
@@ -491,6 +491,7 @@ class LoraBaseMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
@@ -713,8 +714,10 @@ class LoraBaseMixin:
|
||||
# Decompose weights into weights for denoiser and text encoders.
|
||||
_component_adapter_weights = {}
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component)
|
||||
|
||||
model = getattr(self, component, None)
|
||||
if model is None:
|
||||
logger.warning(f"Model {component} not found in pipeline.")
|
||||
continue
|
||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||
if isinstance(weights, dict):
|
||||
component_adapter_weights = weights.pop(component, None)
|
||||
|
||||
@@ -433,7 +433,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
||||
if not is_sparse:
|
||||
# down_weight is copied to each split
|
||||
ait_sd.update({k: down_weight for k in ait_down_keys})
|
||||
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
|
||||
|
||||
# up_weight is split to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
|
||||
@@ -923,7 +923,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
||||
|
||||
# down_weight is copied to each split
|
||||
ait_sd.update({k: down_weight for k in ait_down_keys})
|
||||
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
|
||||
|
||||
# up_weight is split to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
|
||||
|
||||
@@ -91,18 +91,19 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
|
||||
)
|
||||
|
||||
weight_on_cpu = False
|
||||
if not module.weight.is_cuda:
|
||||
if module.weight.device.type == "cpu":
|
||||
weight_on_cpu = True
|
||||
|
||||
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
||||
if is_bnb_4bit_quantized:
|
||||
module_weight = dequantize_bnb_weight(
|
||||
module.weight.cuda() if weight_on_cpu else module.weight,
|
||||
module.weight.to(device) if weight_on_cpu else module.weight,
|
||||
state=module.weight.quant_state,
|
||||
dtype=model.dtype,
|
||||
).data
|
||||
elif is_gguf_quantized:
|
||||
module_weight = dequantize_gguf_tensor(
|
||||
module.weight.cuda() if weight_on_cpu else module.weight,
|
||||
module.weight.to(device) if weight_on_cpu else module.weight,
|
||||
)
|
||||
module_weight = module_weight.to(model.dtype)
|
||||
else:
|
||||
@@ -635,7 +636,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
unet_config=self.unet.config,
|
||||
unet_config=self.unet.config if hasattr(self, "unet") else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -643,37 +644,40 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
unet=self.unet,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix=self.text_encoder_name,
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix=f"{self.text_encoder_name}_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
if hasattr(self, "unet"):
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
unet=self.unet,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
if hasattr(self, "text_encoder"):
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix=self.text_encoder_name,
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix=f"{self.text_encoder_name}_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
@@ -5360,6 +5364,325 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`HiDreamImageTransformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the UNet and text encoder.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -56,6 +56,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -408,6 +408,7 @@ class UNet2DConditionLoadersMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def save_attn_procs(
|
||||
|
||||
@@ -20,12 +20,12 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...models.attention_processor import AttentionProcessor
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import logging
|
||||
from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...models.controlnets.controlnet import ControlNetOutput
|
||||
from ...models.controlnets.controlnet_union import ControlNetUnionModel
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import logging
|
||||
from ..controlnets.controlnet import ControlNetOutput
|
||||
from ..controlnets.controlnet_union import ControlNetUnionModel
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -286,7 +286,7 @@ class KDownsample2D(nn.Module):
|
||||
|
||||
|
||||
class CogVideoXDownsample3D(nn.Module):
|
||||
# Todo: Wait for paper relase.
|
||||
# Todo: Wait for paper release.
|
||||
r"""
|
||||
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
from array import array
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from zipfile import is_zipfile
|
||||
@@ -38,7 +38,6 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerator_device,
|
||||
is_gguf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
@@ -305,51 +304,6 @@ def load_model_dict_into_meta(
|
||||
return offload_index, state_dict_index
|
||||
|
||||
|
||||
# Taken from
|
||||
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
|
||||
def _expand_device_map(device_map, param_names):
|
||||
new_device_map = {}
|
||||
for module, device in device_map.items():
|
||||
new_device_map.update(
|
||||
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
|
||||
)
|
||||
return new_device_map
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
|
||||
# We don't incorporate the `tp_plan` stuff as we don't support it yet.
|
||||
def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> Dict:
|
||||
# Remove disk, cpu and meta devices, and cast to proper torch.device
|
||||
accelerator_device_map = {
|
||||
param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
|
||||
}
|
||||
if not len(accelerator_device_map):
|
||||
return
|
||||
|
||||
total_byte_count = defaultdict(lambda: 0)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = param.numel() * param.element_size()
|
||||
total_byte_count[device] += param_byte_count
|
||||
|
||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||
for device, byte_count in total_byte_count.items():
|
||||
if device.type == "cuda":
|
||||
index = device.index if device.index is not None else torch.cuda.current_device()
|
||||
device_memory = torch.cuda.mem_get_info(index)[0]
|
||||
# Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
|
||||
# than that amount might sometimes lead to unecesary cuda OOM, if the last parameter to be loaded on the device is large,
|
||||
# and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
|
||||
# the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
|
||||
# to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
|
||||
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
|
||||
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
|
||||
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
|
||||
# Allocate memory
|
||||
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
|
||||
|
||||
|
||||
def _load_state_dict_into_model(
|
||||
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
|
||||
) -> List[str]:
|
||||
|
||||
@@ -63,9 +63,7 @@ from ..utils.hub_utils import (
|
||||
populate_model_card,
|
||||
)
|
||||
from .model_loading_utils import (
|
||||
_caching_allocator_warmup,
|
||||
_determine_device_map,
|
||||
_expand_device_map,
|
||||
_fetch_index_file,
|
||||
_fetch_index_file_legacy,
|
||||
_load_state_dict_into_model,
|
||||
@@ -1376,24 +1374,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
return super().float(*args)
|
||||
|
||||
# Taken from `transformers`.
|
||||
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
|
||||
def get_parameter_or_buffer(self, target: str):
|
||||
"""
|
||||
Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
|
||||
`get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a leaf
|
||||
of the model.
|
||||
"""
|
||||
try:
|
||||
return self.get_parameter(target)
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
return self.get_buffer(target)
|
||||
except AttributeError:
|
||||
pass
|
||||
raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
@@ -1430,11 +1410,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
assign_to_params_buffers = None
|
||||
error_msgs = []
|
||||
|
||||
# Optionally, warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None:
|
||||
expanded_device_map = _expand_device_map(device_map, expected_keys)
|
||||
_caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
# Deal with offload
|
||||
if device_map is not None and "disk" in device_map.values():
|
||||
if offload_folder is None:
|
||||
|
||||
@@ -18,10 +18,9 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
@@ -21,16 +21,12 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
StableAudioAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_2d import Transformer2DModelOutput
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_2d import Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -19,18 +19,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
CogVideoXAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous
|
||||
from ...utils import logging
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -73,8 +73,9 @@ class CogView4AdaLayerNormZero(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
|
||||
dtype = hidden_states.dtype
|
||||
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
|
||||
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
|
||||
|
||||
emb = self.linear(temb)
|
||||
(
|
||||
@@ -111,8 +112,11 @@ class CogView4AdaLayerNormZero(nn.Module):
|
||||
|
||||
class CogView4AttnProcessor:
|
||||
"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
|
||||
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
|
||||
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -125,8 +129,10 @@ class CogView4AttnProcessor:
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype = encoder_hidden_states.dtype
|
||||
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
@@ -142,9 +148,9 @@ class CogView4AttnProcessor:
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
query = attn.norm_q(query).to(dtype=dtype)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
key = attn.norm_k(key).to(dtype=dtype)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
@@ -159,13 +165,14 @@ class CogView4AttnProcessor:
|
||||
|
||||
# 4. Attention
|
||||
if attention_mask is not None:
|
||||
text_attention_mask = attention_mask.float().to(query.device)
|
||||
actual_text_seq_length = text_attention_mask.size(1)
|
||||
new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
|
||||
new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
|
||||
new_attention_mask = new_attention_mask.unsqueeze(2)
|
||||
attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
|
||||
attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
|
||||
text_attn_mask = attention_mask
|
||||
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
|
||||
text_attn_mask = text_attn_mask.float().to(query.device)
|
||||
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
|
||||
mix_attn_mask[:, :text_seq_length] = text_attn_mask
|
||||
mix_attn_mask = mix_attn_mask.unsqueeze(2)
|
||||
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
|
||||
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
@@ -183,9 +190,276 @@ class CogView4AttnProcessor:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogView4TrainingAttnProcessor:
|
||||
"""
|
||||
Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
|
||||
embedding on query and key vectors, but does not include spatial normalization.
|
||||
|
||||
This processor differs from CogView4AttnProcessor in several important ways:
|
||||
1. It supports attention masking with variable sequence lengths for multi-resolution training
|
||||
2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
|
||||
provided
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
latent_attn_mask: Optional[torch.Tensor] = None,
|
||||
text_attn_mask: Optional[torch.Tensor] = None,
|
||||
batch_flag: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
attn (`Attention`):
|
||||
The attention module.
|
||||
hidden_states (`torch.Tensor`):
|
||||
The input hidden states.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states for cross-attention.
|
||||
latent_attn_mask (`torch.Tensor`, *optional*):
|
||||
Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
|
||||
attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
|
||||
num_latent_tokens).
|
||||
text_attn_mask (`torch.Tensor`, *optional*):
|
||||
Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
|
||||
is used for all text tokens.
|
||||
batch_flag (`torch.Tensor`, *optional*):
|
||||
Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
|
||||
batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
|
||||
batch1, and samples 3-4 form batch2. If None, no packing is used.
|
||||
image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
|
||||
The rotary embedding for the image part of the input.
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
|
||||
"""
|
||||
|
||||
# Get dimensions and device info
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
dtype = encoder_hidden_states.dtype
|
||||
device = encoder_hidden_states.device
|
||||
latent_hidden_states = hidden_states
|
||||
# Combine text and image streams for joint processing
|
||||
mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
|
||||
|
||||
# 1. Construct attention mask and maybe packing input
|
||||
# Create default masks if not provided
|
||||
if text_attn_mask is None:
|
||||
text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
|
||||
if latent_attn_mask is None:
|
||||
latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
|
||||
|
||||
# Validate mask shapes and types
|
||||
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
|
||||
assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
|
||||
assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
|
||||
assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
|
||||
|
||||
# Create combined mask for text and image tokens
|
||||
mixed_attn_mask = torch.ones(
|
||||
(batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
|
||||
)
|
||||
mixed_attn_mask[:, :text_seq_length] = text_attn_mask
|
||||
mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
|
||||
|
||||
# Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
|
||||
mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
|
||||
attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
|
||||
|
||||
# Handle batch packing if enabled
|
||||
if batch_flag is not None:
|
||||
assert batch_flag.dim() == 1
|
||||
# Determine packed batch size based on batch_flag
|
||||
packing_batch_size = torch.max(batch_flag).item() + 1
|
||||
|
||||
# Calculate actual sequence lengths for each sample based on masks
|
||||
text_seq_length = torch.sum(text_attn_mask, dim=1)
|
||||
latent_seq_length = torch.sum(latent_attn_mask, dim=1)
|
||||
mixed_seq_length = text_seq_length + latent_seq_length
|
||||
|
||||
# Calculate packed sequence lengths for each packed batch
|
||||
mixed_seq_length_packed = [
|
||||
torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
|
||||
]
|
||||
|
||||
assert len(mixed_seq_length_packed) == packing_batch_size
|
||||
|
||||
# Pack sequences by removing padding tokens
|
||||
mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
|
||||
mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
|
||||
mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
|
||||
assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
|
||||
|
||||
# Split the unpadded sequence into packed batches
|
||||
mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
|
||||
|
||||
# Re-pad to create packed batches with right-side padding
|
||||
mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
|
||||
mixed_hidden_states_packed,
|
||||
batch_first=True,
|
||||
padding_value=0.0,
|
||||
padding_side="right",
|
||||
)
|
||||
|
||||
# Create attention mask for packed batches
|
||||
l = mixed_hidden_states_packed_padded.shape[1]
|
||||
attn_mask_matrix = torch.zeros(
|
||||
(packing_batch_size, l, l),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Fill attention mask with block diagonal matrices
|
||||
# This ensures that tokens can only attend to other tokens within the same original sample
|
||||
for idx, mask in enumerate(attn_mask_matrix):
|
||||
seq_lengths = mixed_seq_length[batch_flag == idx]
|
||||
offset = 0
|
||||
for length in seq_lengths:
|
||||
# Create a block of 1s for each sample in the packed batch
|
||||
mask[offset : offset + length, offset : offset + length] = 1
|
||||
offset += length
|
||||
|
||||
attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
|
||||
attn_mask_matrix = attn_mask_matrix.unsqueeze(1) # Add attention head dim
|
||||
attention_mask = attn_mask_matrix
|
||||
|
||||
# Prepare hidden states for attention computation
|
||||
if batch_flag is None:
|
||||
# If no packing, just combine text and image tokens
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
else:
|
||||
# If packing, use the packed sequence
|
||||
hidden_states = mixed_hidden_states_packed_padded
|
||||
|
||||
# 2. QKV projections - convert hidden states to query, key, value
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 3. QK normalization - apply layer norm to queries and keys if configured
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query).to(dtype=dtype)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key).to(dtype=dtype)
|
||||
|
||||
# 4. Apply rotary positional embeddings to image tokens only
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
if batch_flag is None:
|
||||
# Apply RoPE only to image tokens (after text tokens)
|
||||
query[:, :, text_seq_length:, :] = apply_rotary_emb(
|
||||
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
|
||||
)
|
||||
key[:, :, text_seq_length:, :] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
|
||||
)
|
||||
else:
|
||||
# For packed batches, need to carefully apply RoPE to appropriate tokens
|
||||
assert query.shape[0] == packing_batch_size
|
||||
assert key.shape[0] == packing_batch_size
|
||||
assert len(image_rotary_emb) == batch_size
|
||||
|
||||
rope_idx = 0
|
||||
for idx in range(packing_batch_size):
|
||||
offset = 0
|
||||
# Get text and image sequence lengths for samples in this packed batch
|
||||
text_seq_length_bi = text_seq_length[batch_flag == idx]
|
||||
latent_seq_length_bi = latent_seq_length[batch_flag == idx]
|
||||
|
||||
# Apply RoPE to each image segment in the packed sequence
|
||||
for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
|
||||
mlen = tlen + llen
|
||||
# Apply RoPE only to image tokens (after text tokens)
|
||||
query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
|
||||
query[idx, :, offset + tlen : offset + mlen, :],
|
||||
image_rotary_emb[rope_idx],
|
||||
use_real_unbind_dim=-2,
|
||||
)
|
||||
key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
|
||||
key[idx, :, offset + tlen : offset + mlen, :],
|
||||
image_rotary_emb[rope_idx],
|
||||
use_real_unbind_dim=-2,
|
||||
)
|
||||
offset += mlen
|
||||
rope_idx += 1
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# 5. Output projection - project attention output to model dimension
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
# Split the output back into text and image streams
|
||||
if batch_flag is None:
|
||||
# Simple split for non-packed case
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
else:
|
||||
# For packed case: need to unpack, split text/image, then restore to original shapes
|
||||
# First, unpad the sequence based on the packed sequence lengths
|
||||
hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
|
||||
hidden_states,
|
||||
lengths=torch.tensor(mixed_seq_length_packed),
|
||||
batch_first=True,
|
||||
)
|
||||
# Concatenate all unpadded sequences
|
||||
hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
|
||||
# Split by original sample sequence lengths
|
||||
hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
|
||||
assert len(hidden_states_unpack) == batch_size
|
||||
|
||||
# Further split each sample's sequence into text and image parts
|
||||
hidden_states_unpack = [
|
||||
torch.split(h, [tlen, llen])
|
||||
for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
|
||||
]
|
||||
# Separate text and image sequences
|
||||
encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
|
||||
hidden_states_unpad = [h[1] for h in hidden_states_unpack]
|
||||
|
||||
# Update the original tensors with the processed values, respecting the attention masks
|
||||
for idx in range(batch_size):
|
||||
# Place unpacked text tokens back in the encoder_hidden_states tensor
|
||||
encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
|
||||
# Place unpacked image tokens back in the latent_hidden_states tensor
|
||||
latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
|
||||
|
||||
# Update the output hidden states
|
||||
hidden_states = latent_hidden_states
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogView4TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
|
||||
self,
|
||||
dim: int = 2560,
|
||||
num_attention_heads: int = 64,
|
||||
attention_head_dim: int = 40,
|
||||
time_embed_dim: int = 512,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -213,9 +487,11 @@ class CogView4TransformerBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
@@ -232,12 +508,14 @@ class CogView4TransformerBlock(nn.Module):
|
||||
) = self.norm1(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# 2. Attention
|
||||
if attention_kwargs is None:
|
||||
attention_kwargs = {}
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
**attention_kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
|
||||
@@ -402,7 +680,9 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
@@ -422,7 +702,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. RoPE
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
if image_rotary_emb is None:
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Patch & Timestep embeddings
|
||||
p = self.config.patch_size
|
||||
@@ -438,11 +719,22 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
attention_kwargs,
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
|
||||
@@ -21,22 +21,22 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import (
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0_NPU,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -275,7 +275,14 @@ class HiDreamAttnProcessor:
|
||||
|
||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_routed_experts=4,
|
||||
num_activated_experts=2,
|
||||
aux_loss_alpha=0.01,
|
||||
_force_inference_output=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.top_k = num_activated_experts
|
||||
self.n_routed_experts = num_routed_experts
|
||||
@@ -289,9 +296,10 @@ class MoEGate(nn.Module):
|
||||
self.gating_dim = embed_dim
|
||||
self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
|
||||
|
||||
self._force_inference_output = _force_inference_output
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
# print(bsz, seq_len, h)
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states, self.weight, None)
|
||||
@@ -309,7 +317,7 @@ class MoEGate(nn.Module):
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
### expert-level computation auxiliary loss
|
||||
if self.training and self.alpha > 0.0:
|
||||
if self.training and self.alpha > 0.0 and not self._force_inference_output:
|
||||
scores_for_aux = scores
|
||||
aux_topk = self.top_k
|
||||
# always compute aux loss based on the naive greedy topk method
|
||||
@@ -341,14 +349,19 @@ class MOEFeedForwardSwiGLU(nn.Module):
|
||||
hidden_dim: int,
|
||||
num_routed_experts: int,
|
||||
num_activated_experts: int,
|
||||
_force_inference_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
|
||||
self.experts = nn.ModuleList(
|
||||
[HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
|
||||
)
|
||||
self._force_inference_output = _force_inference_output
|
||||
self.gate = MoEGate(
|
||||
embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts
|
||||
embed_dim=dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=_force_inference_output,
|
||||
)
|
||||
self.num_activated_experts = num_activated_experts
|
||||
|
||||
@@ -359,7 +372,7 @@ class MOEFeedForwardSwiGLU(nn.Module):
|
||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if self.training:
|
||||
if self.training and not self._force_inference_output:
|
||||
x = x.repeat_interleave(self.num_activated_experts, dim=0)
|
||||
y = torch.empty_like(x, dtype=wtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
@@ -413,6 +426,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
_force_inference_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -436,6 +450,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
|
||||
hidden_dim=4 * dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=_force_inference_output,
|
||||
)
|
||||
else:
|
||||
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
|
||||
@@ -480,6 +495,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
attention_head_dim: int,
|
||||
num_routed_experts: int = 4,
|
||||
num_activated_experts: int = 2,
|
||||
_force_inference_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -504,6 +520,7 @@ class HiDreamImageTransformerBlock(nn.Module):
|
||||
hidden_dim=4 * dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=_force_inference_output,
|
||||
)
|
||||
else:
|
||||
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
|
||||
@@ -606,6 +623,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
axes_dims_rope: Tuple[int, int] = (32, 32),
|
||||
max_resolution: Tuple[int, int] = (128, 128),
|
||||
llama_layers: List[int] = None,
|
||||
force_inference_output: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
@@ -629,6 +647,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=force_inference_output,
|
||||
)
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
@@ -644,6 +663,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_routed_experts=num_routed_experts,
|
||||
num_activated_experts=num_activated_experts,
|
||||
_force_inference_output=force_inference_output,
|
||||
)
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
@@ -662,7 +682,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
|
||||
if is_training:
|
||||
if is_training and not self.config.force_inference_output:
|
||||
B, S, F = x.shape
|
||||
C = F // (self.config.patch_size * self.config.patch_size)
|
||||
x = (
|
||||
@@ -771,7 +791,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
|
||||
deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
|
||||
deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
|
||||
encoder_hidden_states_t5 = encoder_hidden_states[0]
|
||||
encoder_hidden_states_llama3 = encoder_hidden_states[1]
|
||||
|
||||
@@ -779,7 +799,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
deprecation_message = (
|
||||
"Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
|
||||
)
|
||||
deprecate("img_ids", "0.34.0", deprecation_message)
|
||||
deprecate("img_ids", "0.35.0", deprecation_message)
|
||||
|
||||
if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
|
||||
raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
|
||||
|
||||
@@ -1068,17 +1068,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
latent_sequence_length = hidden_states.shape[1]
|
||||
condition_sequence_length = encoder_hidden_states.shape[1]
|
||||
sequence_length = latent_sequence_length + condition_sequence_length
|
||||
attention_mask = torch.zeros(
|
||||
attention_mask = torch.ones(
|
||||
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
|
||||
) # [B, N]
|
||||
|
||||
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
|
||||
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
|
||||
|
||||
for i in range(batch_size):
|
||||
attention_mask[i, : effective_sequence_length[i]] = True
|
||||
# [B, 1, 1, N], for broadcasting across attention heads
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
|
||||
mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
|
||||
attention_mask = attention_mask.masked_fill(mask_indices, False)
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -18,19 +18,19 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
||||
from ...models.attention import FeedForward, JointTransformerBlock
|
||||
from ...models.attention_processor import (
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward, JointTransformerBlock
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
FusedJointAttnProcessor2_0,
|
||||
JointAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -358,7 +358,7 @@ class KUpsample2D(nn.Module):
|
||||
|
||||
class CogVideoXUpsample3D(nn.Module):
|
||||
r"""
|
||||
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
|
||||
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper release.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
|
||||
@@ -47,6 +47,7 @@ else:
|
||||
"AutoPipelineForInpainting",
|
||||
"AutoPipelineForText2Image",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["ModularLoader"]
|
||||
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
|
||||
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
|
||||
_import_structure["ddim"] = ["DDIMPipeline"]
|
||||
@@ -329,6 +330,8 @@ else:
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
"StableDiffusionXLModularLoader",
|
||||
"StableDiffusionXLAutoPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
@@ -478,6 +481,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
|
||||
from .dit import DiTPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .modular_pipeline import ModularLoader
|
||||
from .pipeline_utils import (
|
||||
AudioPipelineOutput,
|
||||
DiffusionPipeline,
|
||||
@@ -699,9 +703,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLAutoPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLModularLoader,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .stable_video_diffusion import StableVideoDiffusionPipeline
|
||||
|
||||
@@ -514,7 +514,7 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
@@ -246,14 +246,15 @@ def _get_connected_pipeline(pipeline_cls):
|
||||
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
|
||||
|
||||
|
||||
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
|
||||
def get_model(pipeline_class_name):
|
||||
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
|
||||
for model_name, pipeline in task_mapping.items():
|
||||
if pipeline.__name__ == pipeline_class_name:
|
||||
return model_name
|
||||
def _get_model(pipeline_class_name):
|
||||
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
|
||||
for model_name, pipeline in task_mapping.items():
|
||||
if pipeline.__name__ == pipeline_class_name:
|
||||
return model_name
|
||||
|
||||
model_name = get_model(pipeline_class_name)
|
||||
|
||||
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
|
||||
model_name = _get_model(pipeline_class_name)
|
||||
|
||||
if model_name is not None:
|
||||
task_class = mapping.get(model_name, None)
|
||||
|
||||
@@ -0,0 +1,860 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from itertools import combinations
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
|
||||
from accelerate.state import PartialState
|
||||
from accelerate.utils import send_to_device
|
||||
from accelerate.utils.memory import clear_device_cache
|
||||
from accelerate.utils.modeling import convert_file_size_to_int
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
|
||||
def get_memory_footprint(self, return_buffers=True):
|
||||
r"""
|
||||
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to
|
||||
benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch
|
||||
discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
|
||||
|
||||
Arguments:
|
||||
return_buffers (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are
|
||||
tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm
|
||||
layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
|
||||
"""
|
||||
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
|
||||
if return_buffers:
|
||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
|
||||
mem = mem + mem_bufs
|
||||
return mem
|
||||
|
||||
|
||||
class CustomOffloadHook(ModelHook):
|
||||
"""
|
||||
A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
|
||||
on the given device. Optionally offloads other models to the CPU before the forward pass is called.
|
||||
|
||||
Args:
|
||||
execution_device(`str`, `int` or `torch.device`, *optional*):
|
||||
The device on which the model should be executed. Will default to the MPS device if it's available, then
|
||||
GPU 0 if there is a GPU, and finally to the CPU.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execution_device: Optional[Union[str, int, torch.device]] = None,
|
||||
other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
|
||||
offload_strategy: Optional["AutoOffloadStrategy"] = None,
|
||||
):
|
||||
self.execution_device = execution_device if execution_device is not None else PartialState().default_device
|
||||
self.other_hooks = other_hooks
|
||||
self.offload_strategy = offload_strategy
|
||||
self.model_id = None
|
||||
|
||||
def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
|
||||
self.offload_strategy = offload_strategy
|
||||
|
||||
def add_other_hook(self, hook: "UserCustomOffloadHook"):
|
||||
"""
|
||||
Add a hook to the list of hooks to consider for offloading.
|
||||
"""
|
||||
if self.other_hooks is None:
|
||||
self.other_hooks = []
|
||||
self.other_hooks.append(hook)
|
||||
|
||||
def init_hook(self, module):
|
||||
return module.to("cpu")
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
if module.device != self.execution_device:
|
||||
if self.other_hooks is not None:
|
||||
hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
|
||||
# offload all other hooks
|
||||
start_time = time.perf_counter()
|
||||
if self.offload_strategy is not None:
|
||||
hooks_to_offload = self.offload_strategy(
|
||||
hooks=hooks_to_offload,
|
||||
model_id=self.model_id,
|
||||
model=module,
|
||||
execution_device=self.execution_device,
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
|
||||
)
|
||||
|
||||
for hook in hooks_to_offload:
|
||||
logger.info(
|
||||
f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
|
||||
)
|
||||
hook.offload()
|
||||
|
||||
if hooks_to_offload:
|
||||
clear_device_cache()
|
||||
module.to(self.execution_device)
|
||||
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
|
||||
|
||||
|
||||
class UserCustomOffloadHook:
|
||||
"""
|
||||
A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of
|
||||
the hook or remove it entirely.
|
||||
"""
|
||||
|
||||
def __init__(self, model_id, model, hook):
|
||||
self.model_id = model_id
|
||||
self.model = model
|
||||
self.hook = hook
|
||||
|
||||
def offload(self):
|
||||
self.hook.init_hook(self.model)
|
||||
|
||||
def attach(self):
|
||||
add_hook_to_module(self.model, self.hook)
|
||||
self.hook.model_id = self.model_id
|
||||
|
||||
def remove(self):
|
||||
remove_hook_from_module(self.model)
|
||||
self.hook.model_id = None
|
||||
|
||||
def add_other_hook(self, hook: "UserCustomOffloadHook"):
|
||||
self.hook.add_other_hook(hook)
|
||||
|
||||
|
||||
def custom_offload_with_hook(
|
||||
model_id: str,
|
||||
model: torch.nn.Module,
|
||||
execution_device: Union[str, int, torch.device] = None,
|
||||
offload_strategy: Optional["AutoOffloadStrategy"] = None,
|
||||
):
|
||||
hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
|
||||
user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
|
||||
user_hook.attach()
|
||||
return user_hook
|
||||
|
||||
|
||||
class AutoOffloadStrategy:
|
||||
"""
|
||||
Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
|
||||
the available memory on the device.
|
||||
"""
|
||||
|
||||
def __init__(self, memory_reserve_margin="3GB"):
|
||||
self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)
|
||||
|
||||
def __call__(self, hooks, model_id, model, execution_device):
|
||||
if len(hooks) == 0:
|
||||
return []
|
||||
|
||||
current_module_size = get_memory_footprint(model)
|
||||
|
||||
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
|
||||
mem_on_device = mem_on_device - self.memory_reserve_margin
|
||||
if current_module_size < mem_on_device:
|
||||
return []
|
||||
|
||||
min_memory_offload = current_module_size - mem_on_device
|
||||
logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")
|
||||
|
||||
# exlucde models that's not currently loaded on the device
|
||||
module_sizes = dict(
|
||||
sorted(
|
||||
{hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
def search_best_candidate(module_sizes, min_memory_offload):
|
||||
"""
|
||||
search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a
|
||||
minimum memory offload size. the combination of models should add up to the smallest modulesize that is
|
||||
larger than `min_memory_offload`
|
||||
"""
|
||||
model_ids = list(module_sizes.keys())
|
||||
best_candidate = None
|
||||
best_size = float("inf")
|
||||
for r in range(1, len(model_ids) + 1):
|
||||
for candidate_model_ids in combinations(model_ids, r):
|
||||
candidate_size = sum(
|
||||
module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
|
||||
)
|
||||
if candidate_size < min_memory_offload:
|
||||
continue
|
||||
else:
|
||||
if best_candidate is None or candidate_size < best_size:
|
||||
best_candidate = candidate_model_ids
|
||||
best_size = candidate_size
|
||||
|
||||
return best_candidate
|
||||
|
||||
best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)
|
||||
|
||||
if best_offload_model_ids is None:
|
||||
# if no combination is found, meaning that we cannot meet the memory requirement, offload all models
|
||||
logger.warning("no combination of models to offload to cpu is found, offloading all models")
|
||||
hooks_to_offload = hooks
|
||||
else:
|
||||
hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]
|
||||
|
||||
return hooks_to_offload
|
||||
|
||||
|
||||
|
||||
import uuid
|
||||
|
||||
|
||||
class ComponentsManager:
|
||||
def __init__(self):
|
||||
self.components = OrderedDict()
|
||||
self.added_time = OrderedDict() # Store when components were added
|
||||
self.collections = OrderedDict() # collection_name -> set of component_names
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
|
||||
def _get_by_collection(self, collection: str):
|
||||
"""
|
||||
Select components by collection name.
|
||||
"""
|
||||
selected_components = {}
|
||||
if collection in self.collections:
|
||||
component_ids = self.collections[collection]
|
||||
for component_id in component_ids:
|
||||
selected_components[component_id] = self.components[component_id]
|
||||
return selected_components
|
||||
|
||||
|
||||
def _get_by_load_id(self, load_id: str):
|
||||
"""
|
||||
Select components by its load_id.
|
||||
"""
|
||||
selected_components = {}
|
||||
for name, component in self.components.items():
|
||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
|
||||
selected_components[name] = component
|
||||
return selected_components
|
||||
|
||||
|
||||
def add(self, name, component, collection: Optional[str] = None):
|
||||
|
||||
for comp_id, comp in self.components.items():
|
||||
if comp == component:
|
||||
logger.warning(f"Component '{name}' already exists in ComponentsManager")
|
||||
return comp_id
|
||||
|
||||
component_id = f"{name}_{uuid.uuid4()}"
|
||||
|
||||
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
|
||||
components_with_same_load_id = self._get_by_load_id(component._diffusers_load_id)
|
||||
if components_with_same_load_id:
|
||||
existing = ", ".join(components_with_same_load_id.keys())
|
||||
logger.warning(
|
||||
f"Component '{name}' has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
|
||||
f"To remove a duplicate, call `components_manager.remove('<component_name>')`."
|
||||
)
|
||||
|
||||
|
||||
# add component to components manager
|
||||
self.components[component_id] = component
|
||||
self.added_time[component_id] = time.time()
|
||||
if collection:
|
||||
if collection not in self.collections:
|
||||
self.collections[collection] = set()
|
||||
self.collections[collection].add(component_id)
|
||||
|
||||
if self._auto_offload_enabled:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
|
||||
logger.info(f"Added component '{name}' to ComponentsManager as '{component_id}'")
|
||||
return component_id
|
||||
|
||||
|
||||
def remove(self, name: Union[str, List[str]]):
|
||||
|
||||
if name not in self.components:
|
||||
logger.warning(f"Component '{name}' not found in ComponentsManager")
|
||||
return
|
||||
|
||||
self.components.pop(name)
|
||||
self.added_time.pop(name)
|
||||
|
||||
for collection in self.collections:
|
||||
if name in self.collections[collection]:
|
||||
self.collections[collection].remove(name)
|
||||
|
||||
if self._auto_offload_enabled:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
|
||||
def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None,
|
||||
as_name_component_tuples: bool = False):
|
||||
"""
|
||||
Select components by name with simple pattern matching.
|
||||
|
||||
Args:
|
||||
names: Component name(s) or pattern(s)
|
||||
Patterns:
|
||||
- "unet" : match any component with base name "unet" (e.g., unet_123abc)
|
||||
- "!unet" : everything except components with base name "unet"
|
||||
- "unet*" : anything with base name starting with "unet"
|
||||
- "!unet*" : anything with base name NOT starting with "unet"
|
||||
- "*unet*" : anything with base name containing "unet"
|
||||
- "!*unet*" : anything with base name NOT containing "unet"
|
||||
- "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
|
||||
- "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
|
||||
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
|
||||
collection: Optional collection to filter by
|
||||
load_id: Optional load_id to filter by
|
||||
as_name_component_tuples: If True, returns a list of (name, component) tuples using base names
|
||||
instead of a dictionary with component IDs as keys
|
||||
|
||||
Returns:
|
||||
Dictionary mapping component IDs to components,
|
||||
or list of (base_name, component) tuples if as_name_component_tuples=True
|
||||
"""
|
||||
|
||||
if collection:
|
||||
if collection not in self.collections:
|
||||
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
|
||||
return [] if as_name_component_tuples else {}
|
||||
components = self._get_by_collection(collection)
|
||||
else:
|
||||
components = self.components
|
||||
|
||||
if load_id:
|
||||
components = self._get_by_load_id(load_id)
|
||||
|
||||
# Helper to extract base name from component_id
|
||||
def get_base_name(component_id):
|
||||
parts = component_id.split('_')
|
||||
# If the last part looks like a UUID, remove it
|
||||
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
|
||||
return '_'.join(parts[:-1])
|
||||
return component_id
|
||||
|
||||
if names is None:
|
||||
if as_name_component_tuples:
|
||||
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
|
||||
else:
|
||||
return components
|
||||
|
||||
# Create mapping from component_id to base_name for all components
|
||||
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
|
||||
|
||||
def matches_pattern(component_id, pattern, exact_match=False):
|
||||
"""
|
||||
Helper function to check if a component matches a pattern based on its base name.
|
||||
|
||||
Args:
|
||||
component_id: The component ID to check
|
||||
pattern: The pattern to match against
|
||||
exact_match: If True, only exact matches to base_name are considered
|
||||
"""
|
||||
base_name = base_names[component_id]
|
||||
|
||||
# Exact match with base name
|
||||
if exact_match:
|
||||
return pattern == base_name
|
||||
|
||||
# Prefix match (ends with *)
|
||||
elif pattern.endswith('*'):
|
||||
prefix = pattern[:-1]
|
||||
return base_name.startswith(prefix)
|
||||
|
||||
# Contains match (starts with *)
|
||||
elif pattern.startswith('*'):
|
||||
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
|
||||
return search in base_name
|
||||
|
||||
# Exact match (no wildcards)
|
||||
else:
|
||||
return pattern == base_name
|
||||
|
||||
if isinstance(names, str):
|
||||
# Check if this is a "not" pattern
|
||||
is_not_pattern = names.startswith('!')
|
||||
if is_not_pattern:
|
||||
names = names[1:] # Remove the ! prefix
|
||||
|
||||
# Handle OR patterns (containing |)
|
||||
if '|' in names:
|
||||
terms = names.split('|')
|
||||
matches = {}
|
||||
|
||||
for comp_id, comp in components.items():
|
||||
# For OR patterns with exact names (no wildcards), we do exact matching on base names
|
||||
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
|
||||
|
||||
# Check if any of the terms match this component
|
||||
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
|
||||
|
||||
# Flip the decision if this is a NOT pattern
|
||||
if is_not_pattern:
|
||||
should_include = not should_include
|
||||
|
||||
if should_include:
|
||||
matches[comp_id] = comp
|
||||
|
||||
log_msg = "NOT " if is_not_pattern else ""
|
||||
match_type = "exactly matching" if exact_match else "matching any of patterns"
|
||||
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
|
||||
|
||||
# Try exact match with a base name
|
||||
elif any(names == base_name for base_name in base_names.values()):
|
||||
# Find all components with this base name
|
||||
matches = {
|
||||
comp_id: comp for comp_id, comp in components.items()
|
||||
if (base_names[comp_id] == names) != is_not_pattern
|
||||
}
|
||||
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
|
||||
|
||||
# Prefix match (ends with *)
|
||||
elif names.endswith('*'):
|
||||
prefix = names[:-1]
|
||||
matches = {
|
||||
comp_id: comp for comp_id, comp in components.items()
|
||||
if base_names[comp_id].startswith(prefix) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
|
||||
|
||||
# Contains match (starts with *)
|
||||
elif names.startswith('*'):
|
||||
search = names[1:-1] if names.endswith('*') else names[1:]
|
||||
matches = {
|
||||
comp_id: comp for comp_id, comp in components.items()
|
||||
if (search in base_names[comp_id]) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
|
||||
|
||||
# Substring match (no wildcards, but not an exact component name)
|
||||
elif any(names in base_name for base_name in base_names.values()):
|
||||
matches = {
|
||||
comp_id: comp for comp_id, comp in components.items()
|
||||
if (names in base_names[comp_id]) != is_not_pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
|
||||
|
||||
if not matches:
|
||||
raise ValueError(f"No components found matching pattern '{names}'")
|
||||
|
||||
if as_name_component_tuples:
|
||||
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
|
||||
else:
|
||||
return matches
|
||||
|
||||
elif isinstance(names, list):
|
||||
results = {}
|
||||
for name in names:
|
||||
result = self.get(name, collection, load_id, as_name_component_tuples=False)
|
||||
results.update(result)
|
||||
|
||||
if as_name_component_tuples:
|
||||
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
|
||||
else:
|
||||
return results
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid type for names: {type(names)}")
|
||||
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"):
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
|
||||
remove_hook_from_module(component, recurse=True)
|
||||
|
||||
self.disable_auto_cpu_offload()
|
||||
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
device = torch.device(f"{device.type}:{0}")
|
||||
all_hooks = []
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
|
||||
all_hooks.append(hook)
|
||||
|
||||
for hook in all_hooks:
|
||||
other_hooks = [h for h in all_hooks if h is not hook]
|
||||
for other_hook in other_hooks:
|
||||
if other_hook.hook.execution_device == hook.hook.execution_device:
|
||||
hook.add_other_hook(other_hook)
|
||||
|
||||
self.model_hooks = all_hooks
|
||||
self._auto_offload_enabled = True
|
||||
self._auto_offload_device = device
|
||||
|
||||
def disable_auto_cpu_offload(self):
|
||||
if self.model_hooks is None:
|
||||
self._auto_offload_enabled = False
|
||||
return
|
||||
|
||||
for hook in self.model_hooks:
|
||||
hook.offload()
|
||||
hook.remove()
|
||||
if self.model_hooks:
|
||||
clear_device_cache()
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
# YiYi TODO: add quantization info
|
||||
def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Get comprehensive information about a component.
|
||||
|
||||
Args:
|
||||
name: Name of the component to get info for
|
||||
fields: Optional field(s) to return. Can be a string for single field or list of fields.
|
||||
If None, returns all fields.
|
||||
|
||||
Returns:
|
||||
Dictionary containing requested component metadata.
|
||||
If fields is specified, returns only those fields.
|
||||
If a single field is requested as string, returns just that field's value.
|
||||
"""
|
||||
if name not in self.components:
|
||||
raise ValueError(f"Component '{name}' not found in ComponentsManager")
|
||||
|
||||
component = self.components[name]
|
||||
|
||||
# Build complete info dict first
|
||||
info = {
|
||||
"model_id": name,
|
||||
"added_time": self.added_time[name],
|
||||
"collection": next((coll for coll, comps in self.collections.items() if name in comps), None),
|
||||
}
|
||||
|
||||
# Additional info for torch.nn.Module components
|
||||
if isinstance(component, torch.nn.Module):
|
||||
# Check for hook information
|
||||
has_hook = hasattr(component, "_hf_hook")
|
||||
execution_device = None
|
||||
if has_hook and hasattr(component._hf_hook, "execution_device"):
|
||||
execution_device = component._hf_hook.execution_device
|
||||
|
||||
info.update({
|
||||
"class_name": component.__class__.__name__,
|
||||
"size_gb": get_memory_footprint(component) / (1024**3),
|
||||
"adapters": None, # Default to None
|
||||
"has_hook": has_hook,
|
||||
"execution_device": execution_device,
|
||||
})
|
||||
|
||||
# Get adapters if applicable
|
||||
if hasattr(component, "peft_config"):
|
||||
info["adapters"] = list(component.peft_config.keys())
|
||||
|
||||
# Check for IP-Adapter scales
|
||||
if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
|
||||
processors = copy.deepcopy(component.attn_processors)
|
||||
# First check if any processor is an IP-Adapter
|
||||
processor_types = [v.__class__.__name__ for v in processors.values()]
|
||||
if any("IPAdapter" in ptype for ptype in processor_types):
|
||||
# Then get scales only from IP-Adapter processors
|
||||
scales = {
|
||||
k: v.scale
|
||||
for k, v in processors.items()
|
||||
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
|
||||
}
|
||||
if scales:
|
||||
info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
|
||||
|
||||
# If fields specified, filter info
|
||||
if fields is not None:
|
||||
if isinstance(fields, str):
|
||||
# Single field requested, return just that value
|
||||
return {fields: info.get(fields)}
|
||||
else:
|
||||
# List of fields requested, return dict with just those fields
|
||||
return {k: v for k, v in info.items() if k in fields}
|
||||
|
||||
return info
|
||||
|
||||
def __repr__(self):
|
||||
# Helper to get simple name without UUID
|
||||
def get_simple_name(name):
|
||||
# Extract the base name by splitting on underscore and taking first part
|
||||
# This assumes names are in format "name_uuid"
|
||||
parts = name.split('_')
|
||||
# If we have at least 2 parts and the last part looks like a UUID, remove it
|
||||
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
|
||||
return '_'.join(parts[:-1])
|
||||
return name
|
||||
|
||||
# Extract load_id if available
|
||||
def get_load_id(component):
|
||||
if hasattr(component, "_diffusers_load_id"):
|
||||
return component._diffusers_load_id
|
||||
return "N/A"
|
||||
|
||||
# Format device info compactly
|
||||
def format_device(component, info):
|
||||
if not info["has_hook"]:
|
||||
return str(getattr(component, 'device', 'N/A'))
|
||||
else:
|
||||
device = str(getattr(component, 'device', 'N/A'))
|
||||
exec_device = str(info['execution_device'] or 'N/A')
|
||||
return f"{device}({exec_device})"
|
||||
|
||||
# Get all simple names to calculate width
|
||||
simple_names = [get_simple_name(id) for id in self.components.keys()]
|
||||
|
||||
# Get max length of load_ids for models
|
||||
load_ids = [
|
||||
get_load_id(component)
|
||||
for component in self.components.values()
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
|
||||
]
|
||||
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
|
||||
|
||||
# Collection names
|
||||
collection_names = [
|
||||
next((coll for coll, comps in self.collections.items() if name in comps), "N/A")
|
||||
for name in self.components.keys()
|
||||
]
|
||||
|
||||
col_widths = {
|
||||
"name": max(15, max(len(name) for name in simple_names)),
|
||||
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
|
||||
"device": 15, # Reduced since using more compact format
|
||||
"dtype": 15,
|
||||
"size": 10,
|
||||
"load_id": max_load_id_len,
|
||||
"collection": max(10, max(len(str(c)) for c in collection_names))
|
||||
}
|
||||
|
||||
# Create the header lines
|
||||
sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
|
||||
dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
|
||||
|
||||
output = "Components:\n" + sep_line
|
||||
|
||||
# Separate components into models and others
|
||||
models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
|
||||
others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}
|
||||
|
||||
# Models section
|
||||
if models:
|
||||
output += "Models:\n" + dash_line
|
||||
# Column headers
|
||||
output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | "
|
||||
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
|
||||
output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
|
||||
output += dash_line
|
||||
|
||||
# Model entries
|
||||
for name, component in models.items():
|
||||
info = self.get_model_info(name)
|
||||
simple_name = get_simple_name(name)
|
||||
device_str = format_device(component, info)
|
||||
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
|
||||
load_id = get_load_id(component)
|
||||
collection = info["collection"] or "N/A"
|
||||
|
||||
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
|
||||
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
|
||||
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {collection}\n"
|
||||
output += dash_line
|
||||
|
||||
# Other components section
|
||||
if others:
|
||||
if models: # Add extra newline if we had models section
|
||||
output += "\n"
|
||||
output += "Other Components:\n" + dash_line
|
||||
# Column headers for other components
|
||||
output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n"
|
||||
output += dash_line
|
||||
|
||||
# Other component entries
|
||||
for name, component in others.items():
|
||||
info = self.get_model_info(name)
|
||||
simple_name = get_simple_name(name)
|
||||
collection = info["collection"] or "N/A"
|
||||
|
||||
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {collection}\n"
|
||||
output += dash_line
|
||||
|
||||
# Add additional component info
|
||||
output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
|
||||
for name in self.components:
|
||||
info = self.get_model_info(name)
|
||||
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
|
||||
simple_name = get_simple_name(name)
|
||||
output += f"\n{simple_name}:\n"
|
||||
if info.get("adapters") is not None:
|
||||
output += f" Adapters: {info['adapters']}\n"
|
||||
if info.get("ip_adapter"):
|
||||
output += " IP-Adapter: Enabled\n"
|
||||
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
|
||||
|
||||
return output
|
||||
|
||||
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Load components from a pretrained model and add them to the manager.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (str): The path or identifier of the pretrained model
|
||||
prefix (str, optional): Prefix to add to all component names loaded from this model.
|
||||
If provided, components will be named as "{prefix}_{component_name}"
|
||||
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
|
||||
"""
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
# YiYi TODO: extend AutoModel to support non-diffusers models
|
||||
if subfolder:
|
||||
from ..models import AutoModel
|
||||
component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
|
||||
component_name = f"{prefix}_{subfolder}" if prefix else subfolder
|
||||
if component_name not in self.components:
|
||||
self.add(component_name, component)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
|
||||
f"1. remove the existing component with remove('{component_name}')\n"
|
||||
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_utils import DiffusionPipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
for name, component in pipe.components.items():
|
||||
|
||||
if component is None:
|
||||
continue
|
||||
|
||||
# Add prefix if specified
|
||||
component_name = f"{prefix}_{name}" if prefix else name
|
||||
|
||||
if component_name not in self.components:
|
||||
self.add(component_name, component)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
|
||||
f"1. remove the existing component with remove('{component_name}')\n"
|
||||
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
|
||||
)
|
||||
|
||||
def get_one(self, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Get a single component by name. Raises an error if multiple components match or none are found.
|
||||
|
||||
Args:
|
||||
name: Component name or pattern
|
||||
collection: Optional collection to filter by
|
||||
load_id: Optional load_id to filter by
|
||||
|
||||
Returns:
|
||||
A single component
|
||||
|
||||
Raises:
|
||||
ValueError: If no components match or multiple components match
|
||||
"""
|
||||
results = self.get(name, collection, load_id)
|
||||
|
||||
if not results:
|
||||
raise ValueError(f"No components found matching '{name}'")
|
||||
|
||||
if len(results) > 1:
|
||||
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
|
||||
|
||||
return next(iter(results.values()))
|
||||
|
||||
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Summarizes a dictionary by finding common prefixes that share the same value.
|
||||
|
||||
For a dictionary with dot-separated keys like:
|
||||
{
|
||||
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
|
||||
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
|
||||
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
|
||||
}
|
||||
|
||||
Returns a dictionary where keys are the shortest common prefixes and values are their shared values:
|
||||
{
|
||||
'down_blocks': [0.6],
|
||||
'up_blocks': [0.3]
|
||||
}
|
||||
"""
|
||||
# First group by values - convert lists to tuples to make them hashable
|
||||
value_to_keys = {}
|
||||
for key, value in d.items():
|
||||
value_tuple = tuple(value) if isinstance(value, list) else value
|
||||
if value_tuple not in value_to_keys:
|
||||
value_to_keys[value_tuple] = []
|
||||
value_to_keys[value_tuple].append(key)
|
||||
|
||||
def find_common_prefix(keys: List[str]) -> str:
|
||||
"""Find the shortest common prefix among a list of dot-separated keys."""
|
||||
if not keys:
|
||||
return ""
|
||||
if len(keys) == 1:
|
||||
return keys[0]
|
||||
|
||||
# Split all keys into parts
|
||||
key_parts = [k.split('.') for k in keys]
|
||||
|
||||
# Find how many initial parts are common
|
||||
common_length = 0
|
||||
for parts in zip(*key_parts):
|
||||
if len(set(parts)) == 1: # All parts at this position are the same
|
||||
common_length += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if common_length == 0:
|
||||
return ""
|
||||
|
||||
# Return the common prefix
|
||||
return '.'.join(key_parts[0][:common_length])
|
||||
|
||||
# Create summary by finding common prefixes for each value group
|
||||
summary = {}
|
||||
for value_tuple, keys in value_to_keys.items():
|
||||
prefix = find_common_prefix(keys)
|
||||
if prefix: # Only add if we found a common prefix
|
||||
# Convert tuple back to list if it was originally a list
|
||||
value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
|
||||
summary[prefix] = value
|
||||
else:
|
||||
summary[""] = value # Use empty string if no common prefix
|
||||
|
||||
return summary
|
||||
@@ -912,12 +912,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -931,6 +925,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
@@ -867,12 +867,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -886,6 +880,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
@@ -484,7 +484,7 @@ class IFPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
@@ -528,7 +528,7 @@ class IFImg2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
@@ -281,7 +281,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoa
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
@@ -568,7 +568,7 @@ class IFInpaintingPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
@@ -283,7 +283,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLora
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
@@ -239,7 +239,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
+1
-1
@@ -574,7 +574,7 @@ class StableDiffusionModelEditingPipeline(
|
||||
idxs_replace.append(76)
|
||||
idxs_replaces.append(idxs_replace)
|
||||
|
||||
# prepare batch: for each pair of setences, old context and new values
|
||||
# prepare batch: for each pair of sentences, old context and new values
|
||||
contexts, valuess = [], []
|
||||
for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
|
||||
context = old_emb.detach()
|
||||
|
||||
@@ -490,14 +490,6 @@ class FluxPipeline(
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
@@ -821,7 +813,7 @@ class FluxPipeline(
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
_,
|
||||
negative_text_ids,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_2=negative_prompt_2,
|
||||
@@ -938,7 +930,7 @@ class FluxPipeline(
|
||||
guidance=guidance,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
txt_ids=negative_text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
|
||||
@@ -800,17 +800,20 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
# xlab controlnet has a input_hint_block and instantx controlnet does not
|
||||
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
|
||||
if self.controlnet.input_hint_block is None:
|
||||
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
if control_mode is not None:
|
||||
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
|
||||
@@ -819,7 +822,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
elif isinstance(self.controlnet, FluxMultiControlNetModel):
|
||||
control_images = []
|
||||
|
||||
for control_image_ in control_image:
|
||||
# xlab controlnet has a input_hint_block and instantx controlnet does not
|
||||
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
|
||||
for i, control_image_ in enumerate(control_image):
|
||||
control_image_ = self.prepare_image(
|
||||
image=control_image_,
|
||||
width=width,
|
||||
@@ -831,17 +836,18 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
if self.controlnet.nets[0].input_hint_block is None:
|
||||
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
control_images.append(control_image_)
|
||||
|
||||
@@ -955,6 +961,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
controlnet_blocks_repeat=controlnet_blocks_repeat,
|
||||
)[0]
|
||||
|
||||
latents_dtype = latents.dtype
|
||||
|
||||
@@ -13,6 +13,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import HiDreamImageLoraLoaderMixin
|
||||
from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
|
||||
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
@@ -142,7 +143,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class HiDreamImagePipeline(DiffusionPipeline):
|
||||
class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]
|
||||
|
||||
@@ -822,13 +823,13 @@ class HiDreamImagePipeline(DiffusionPipeline):
|
||||
|
||||
if prompt_embeds is not None:
|
||||
deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead."
|
||||
deprecate("prompt_embeds", "0.34.0", deprecation_message)
|
||||
deprecate("prompt_embeds", "0.35.0", deprecation_message)
|
||||
prompt_embeds_t5 = prompt_embeds[0]
|
||||
prompt_embeds_llama3 = prompt_embeds[1]
|
||||
|
||||
if negative_prompt_embeds is not None:
|
||||
deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead."
|
||||
deprecate("negative_prompt_embeds", "0.34.0", deprecation_message)
|
||||
deprecate("negative_prompt_embeds", "0.35.0", deprecation_message)
|
||||
negative_prompt_embeds_t5 = negative_prompt_embeds[0]
|
||||
negative_prompt_embeds_llama3 = negative_prompt_embeds[1]
|
||||
|
||||
|
||||
@@ -14,14 +14,13 @@
|
||||
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
XLMRobertaTokenizer,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import (
|
||||
@@ -95,15 +94,6 @@ def get_new_h_w(h, w, scale_factor=8):
|
||||
return new_h * scale_factor, new_w * scale_factor
|
||||
|
||||
|
||||
def prepare_image(pil_image, w=512, h=512):
|
||||
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
|
||||
arr = np.array(pil_image.convert("RGB"))
|
||||
arr = arr.astype(np.float32) / 127.5 - 1
|
||||
arr = np.transpose(arr, [2, 0, 1])
|
||||
image = torch.from_numpy(arr).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
class KandinskyImg2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for image-to-image generation using Kandinsky
|
||||
@@ -143,7 +133,16 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
|
||||
scheduler=scheduler,
|
||||
movq=movq,
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
self.movq_scale_factor = (
|
||||
2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
|
||||
)
|
||||
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.movq_scale_factor,
|
||||
vae_latent_channels=movq_latent_channels,
|
||||
resample="bicubic",
|
||||
reducing_gap=1,
|
||||
)
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
@@ -417,7 +416,7 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
|
||||
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
|
||||
)
|
||||
|
||||
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
|
||||
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
|
||||
image = image.to(dtype=prompt_embeds.dtype, device=device)
|
||||
|
||||
latents = self.movq.encode(image)["latents"]
|
||||
@@ -498,13 +497,7 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
image = self.image_processor.postprocess(image, output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
@@ -14,11 +14,10 @@
|
||||
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
@@ -105,27 +104,6 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
if height % scale_factor**2 != 0:
|
||||
new_height += 1
|
||||
new_width = width // scale_factor**2
|
||||
if width % scale_factor**2 != 0:
|
||||
new_width += 1
|
||||
return new_height * scale_factor, new_width * scale_factor
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
|
||||
def prepare_image(pil_image, w=512, h=512):
|
||||
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
|
||||
arr = np.array(pil_image.convert("RGB"))
|
||||
arr = arr.astype(np.float32) / 127.5 - 1
|
||||
arr = np.transpose(arr, [2, 0, 1])
|
||||
image = torch.from_numpy(arr).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for image-to-image generation using Kandinsky
|
||||
@@ -157,7 +135,14 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
|
||||
scheduler=scheduler,
|
||||
movq=movq,
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
|
||||
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=movq_scale_factor,
|
||||
vae_latent_channels=movq_latent_channels,
|
||||
resample="bicubic",
|
||||
reducing_gap=1,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
@@ -316,7 +301,7 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
|
||||
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
|
||||
)
|
||||
|
||||
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
|
||||
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
|
||||
image = image.to(dtype=image_embeds.dtype, device=device)
|
||||
|
||||
latents = self.movq.encode(image)["latents"]
|
||||
@@ -324,7 +309,6 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
|
||||
latents = self.prepare_latents(
|
||||
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
|
||||
)
|
||||
@@ -379,13 +363,7 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
image = self.image_processor.postprocess(image, output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
@@ -14,11 +14,10 @@
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import deprecate, is_torch_xla_available, logging
|
||||
@@ -76,27 +75,6 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
if height % scale_factor**2 != 0:
|
||||
new_height += 1
|
||||
new_width = width // scale_factor**2
|
||||
if width % scale_factor**2 != 0:
|
||||
new_width += 1
|
||||
return new_height * scale_factor, new_width * scale_factor
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
|
||||
def prepare_image(pil_image, w=512, h=512):
|
||||
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
|
||||
arr = np.array(pil_image.convert("RGB"))
|
||||
arr = arr.astype(np.float32) / 127.5 - 1
|
||||
arr = np.transpose(arr, [2, 0, 1])
|
||||
image = torch.from_numpy(arr).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for image-to-image generation using Kandinsky
|
||||
@@ -129,7 +107,14 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
|
||||
scheduler=scheduler,
|
||||
movq=movq,
|
||||
)
|
||||
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
|
||||
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
|
||||
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=movq_scale_factor,
|
||||
vae_latent_channels=movq_latent_channels,
|
||||
resample="bicubic",
|
||||
reducing_gap=1,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
@@ -319,7 +304,7 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
|
||||
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
|
||||
)
|
||||
|
||||
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
|
||||
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
|
||||
image = image.to(dtype=image_embeds.dtype, device=device)
|
||||
|
||||
latents = self.movq.encode(image)["latents"]
|
||||
@@ -327,7 +312,6 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
|
||||
latents = self.prepare_latents(
|
||||
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
|
||||
)
|
||||
@@ -383,21 +367,9 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type not in ["pt", "np", "pil", "latent"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
|
||||
)
|
||||
|
||||
if not output_type == "latent":
|
||||
# post-processing
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
image = self.image_processor.postprocess(image, output_type)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import StableDiffusionLoraLoaderMixin
|
||||
from ...models import Kandinsky3UNet, VQModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
@@ -53,24 +53,6 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
def downscale_height_and_width(height, width, scale_factor=8):
|
||||
new_height = height // scale_factor**2
|
||||
if height % scale_factor**2 != 0:
|
||||
new_height += 1
|
||||
new_width = width // scale_factor**2
|
||||
if width % scale_factor**2 != 0:
|
||||
new_width += 1
|
||||
return new_height * scale_factor, new_width * scale_factor
|
||||
|
||||
|
||||
def prepare_image(pil_image):
|
||||
arr = np.array(pil_image.convert("RGB"))
|
||||
arr = arr.astype(np.float32) / 127.5 - 1
|
||||
arr = np.transpose(arr, [2, 0, 1])
|
||||
image = torch.from_numpy(arr).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
|
||||
model_cpu_offload_seq = "text_encoder->movq->unet->movq"
|
||||
_callback_tensor_inputs = [
|
||||
@@ -94,6 +76,14 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
|
||||
)
|
||||
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
|
||||
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=movq_scale_factor,
|
||||
vae_latent_channels=movq_latent_channels,
|
||||
resample="bicubic",
|
||||
reducing_gap=1,
|
||||
)
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
@@ -566,7 +556,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
|
||||
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
|
||||
)
|
||||
|
||||
image = torch.cat([prepare_image(i) for i in image], dim=0)
|
||||
image = torch.cat([self.image_processor.preprocess(i) for i in image], dim=0)
|
||||
image = image.to(dtype=prompt_embeds.dtype, device=device)
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -630,20 +620,9 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
|
||||
xm.mark_step()
|
||||
|
||||
# post-processing
|
||||
if output_type not in ["pt", "np", "pil", "latent"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
|
||||
)
|
||||
if not output_type == "latent":
|
||||
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
|
||||
|
||||
if output_type in ["np", "pil"]:
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
image = self.image_processor.postprocess(image, output_type)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
|
||||
@@ -609,12 +609,6 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -628,6 +622,11 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
@@ -501,7 +501,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
@@ -534,7 +534,7 @@ class LuminaPipeline(DiffusionPipeline):
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
# ip addresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user