Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f9e27de31a | |||
| 05e867784d | |||
| d72184eba3 | |||
| 5ce4814af1 | |||
| 1bc6f3dc0f | |||
| 79bd7ecc78 | |||
| 9b834f8710 | |||
| 81426b0f19 | |||
| f0dba33d82 | |||
| d1db4f853a | |||
| 8adc6003ba | |||
| 9f91305f85 | |||
| 368958df6f | |||
| e52ceae375 | |||
| 62cbde8d41 | |||
| 648e8955cf | |||
| 00b179fb1a | |||
| 47ef79464f | |||
| b272807bc8 | |||
| 447ccd0679 | |||
| f3e09114f2 | |||
| 91545666e0 | |||
| b6f7933044 | |||
| 33e636cea5 | |||
| e27142ac64 | |||
| 8e88495da2 | |||
| b79803fe08 | |||
| b0f7036d9a | |||
| 6c7fad7ec8 | |||
| 5b0dab1253 | |||
| 7c6e9ef425 | |||
| f46abfe4ce | |||
| 73a9d5856f | |||
| 16c955c5fd |
@@ -14,4 +14,4 @@ jobs:
|
||||
with:
|
||||
python_quality_dependencies: "[quality]"
|
||||
secrets:
|
||||
bot_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
bot_token: ${{ secrets.HF_STYLE_BOT_ACTION }}
|
||||
@@ -180,6 +180,8 @@
|
||||
title: Caching
|
||||
- local: optimization/memory
|
||||
title: Reduce memory usage
|
||||
- local: optimization/pruna
|
||||
title: Pruna
|
||||
- local: optimization/xformers
|
||||
title: xFormers
|
||||
- local: optimization/tome
|
||||
@@ -283,6 +285,8 @@
|
||||
title: AllegroTransformer3DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/chroma_transformer
|
||||
title: ChromaTransformer2DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/cogview3plus_transformer2d
|
||||
@@ -405,6 +409,8 @@
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/chroma
|
||||
title: Chroma
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/cogview3
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ChromaTransformer2DModel
|
||||
|
||||
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
|
||||
|
||||
## ChromaTransformer2DModel
|
||||
|
||||
[[autodoc]] ChromaTransformer2DModel
|
||||
@@ -0,0 +1,71 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Chroma
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Chroma is a text to image generation model based on Flux.
|
||||
|
||||
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
|
||||
|
||||
<Tip>
|
||||
|
||||
Chroma can use all the same optimizations as Flux.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Inference (Single File)
|
||||
|
||||
The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
|
||||
|
||||
The following example demonstrates how to run Chroma from a single file.
|
||||
|
||||
Then run the following example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ChromaTransformer2DModel, ChromaPipeline
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
bfl_repo = "black-forest-labs/FLUX.1-dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt,
|
||||
guidance_scale=4.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=26,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0]
|
||||
|
||||
image.save("image.png")
|
||||
```
|
||||
|
||||
## ChromaPipeline
|
||||
|
||||
[[autodoc]] ChromaPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -36,6 +36,22 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Cosmos2TextToImagePipeline
|
||||
|
||||
[[autodoc]] Cosmos2TextToImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Cosmos2VideoToWorldPipeline
|
||||
|
||||
[[autodoc]] Cosmos2VideoToWorldPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||
|
||||
## CosmosImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosImagePipelineOutput
|
||||
|
||||
@@ -22,17 +22,30 @@
|
||||
|
||||
# Wan2.1
|
||||
|
||||
[Wan2.1](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf) is a series of large diffusion transformer available in two versions, a high-performance 14B parameter model and a more accessible 1.3B version. Trained on billions of images and videos, it supports tasks like text-to-video (T2V) and image-to-video (I2V) while enabling features such as camera control and stylistic diversity. The Wan-VAE features better image data compression and a feature cache mechanism that encodes and decodes a video in chunks. To maintain continuity, features from previous chunks are cached and reused for processing subsequent chunks. This improves inference efficiency by reducing memory usage. Wan2.1 also uses a multilingual text encoder and the diffusion transformer models space and time relationships and text conditions with each time step to capture more complex video dynamics.
|
||||
[Wan-2.1](https://huggingface.co/papers/2503.20314) by the Wan Team.
|
||||
|
||||
*This report presents Wan, a comprehensive and open suite of video foundation models designed to push the boundaries of video generation. Built upon the mainstream diffusion transformer paradigm, Wan achieves significant advancements in generative capabilities through a series of innovations, including our novel VAE, scalable pre-training strategies, large-scale data curation, and automated evaluation metrics. These contributions collectively enhance the model's performance and versatility. Specifically, Wan is characterized by four key features: Leading Performance: The 14B model of Wan, trained on a vast dataset comprising billions of images and videos, demonstrates the scaling laws of video generation with respect to both data and model size. It consistently outperforms the existing open-source models as well as state-of-the-art commercial solutions across multiple internal and external benchmarks, demonstrating a clear and significant performance superiority. Comprehensiveness: Wan offers two capable models, i.e., 1.3B and 14B parameters, for efficiency and effectiveness respectively. It also covers multiple downstream applications, including image-to-video, instruction-guided video editing, and personal video generation, encompassing up to eight tasks. Consumer-Grade Efficiency: The 1.3B model demonstrates exceptional resource efficiency, requiring only 8.19 GB VRAM, making it compatible with a wide range of consumer-grade GPUs. Openness: We open-source the entire series of Wan, including source code and all models, with the goal of fostering the growth of the video generation community. This openness seeks to significantly expand the creative possibilities of video production in the industry and provide academia with high-quality video foundation models. All the code and models are available at [this https URL](https://github.com/Wan-Video/Wan2.1).*
|
||||
|
||||
You can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization.
|
||||
|
||||
The following Wan models are supported in Diffusers:
|
||||
- [Wan 2.1 T2V 1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
|
||||
- [Wan 2.1 T2V 14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
|
||||
- [Wan 2.1 I2V 14B - 480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
|
||||
- [Wan 2.1 I2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
|
||||
- [Wan 2.1 FLF2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
|
||||
- [Wan 2.1 VACE 1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers)
|
||||
- [Wan 2.1 VACE 14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers)
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Wan2.1 models in the right sidebar for more examples of video generation.
|
||||
|
||||
### Text-to-Video Generation
|
||||
|
||||
The example below demonstrates how to generate a video from text optimized for memory or inference speed.
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="memory">
|
||||
<hfoptions id="T2V usage">
|
||||
<hfoption id="T2V memory">
|
||||
|
||||
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
|
||||
|
||||
@@ -100,7 +113,7 @@ export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="inference speed">
|
||||
<hfoption id="T2V inference speed">
|
||||
|
||||
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
|
||||
|
||||
@@ -157,6 +170,81 @@ export_to_video(output, "output.mp4", fps=16)
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### First-Last-Frame-to-Video Generation
|
||||
|
||||
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
|
||||
|
||||
<hfoptions id="FLF2V usage">
|
||||
<hfoption id="usage">
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
|
||||
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
|
||||
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
|
||||
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
def center_crop_resize(image, height, width):
|
||||
# Calculate resize ratio to match first frame dimensions
|
||||
resize_ratio = max(width / image.width, height / image.height)
|
||||
|
||||
# Resize the image
|
||||
width = round(image.width * resize_ratio)
|
||||
height = round(image.height * resize_ratio)
|
||||
size = [width, height]
|
||||
image = TF.center_crop(image, size)
|
||||
|
||||
return image, height, width
|
||||
|
||||
first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
|
||||
if last_frame.size != first_frame.size:
|
||||
last_frame, _, _ = center_crop_resize(last_frame, height, width)
|
||||
|
||||
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
|
||||
|
||||
output = pipe(
|
||||
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Any-to-Video Controllable Generation
|
||||
|
||||
Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include:
|
||||
- Control to Video (Depth, Pose, Sketch, Flow, Grayscale, Scribble, Layout, Boundary Box, etc.). Recommended library for preprocessing videos to obtain control videos: [huggingface/controlnet_aux]()
|
||||
- Image/Video to Video (first frame, last frame, starting clip, ending clip, random clips)
|
||||
- Inpainting and Outpainting
|
||||
- Subject to Video (faces, object, characters, etc.)
|
||||
- Composition to Video (reference anything, animate anything, swap anything, expand anything, move anything, etc.)
|
||||
|
||||
The code snippets available in [this](https://github.com/huggingface/diffusers/pull/11582) pull request demonstrate some examples of how videos can be generated with controllability signals.
|
||||
|
||||
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
|
||||
|
||||
## Notes
|
||||
|
||||
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
|
||||
@@ -251,6 +339,18 @@ export_to_video(output, "output.mp4", fps=16)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanVACEPipeline
|
||||
|
||||
[[autodoc]] WanVACEPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanVideoToVideoPipeline
|
||||
|
||||
[[autodoc]] WanVideoToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
|
||||
@@ -0,0 +1,187 @@
|
||||
# Pruna
|
||||
|
||||
[Pruna](https://github.com/PrunaAI/pruna) is a model optimization framework that offers various optimization methods - quantization, pruning, caching, compilation - for accelerating inference and reducing memory usage. A general overview of the optimization methods are shown below.
|
||||
|
||||
|
||||
| Technique | Description | Speed | Memory | Quality |
|
||||
|--------------|-----------------------------------------------------------------------------------------------|:-----:|:------:|:-------:|
|
||||
| `batcher` | Groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing processing time. | ✅ | ❌ | ➖ |
|
||||
| `cacher` | Stores intermediate results of computations to speed up subsequent operations. | ✅ | ➖ | ➖ |
|
||||
| `compiler` | Optimises the model with instructions for specific hardware. | ✅ | ➖ | ➖ |
|
||||
| `distiller` | Trains a smaller, simpler model to mimic a larger, more complex model. | ✅ | ✅ | ❌ |
|
||||
| `quantizer` | Reduces the precision of weights and activations, lowering memory requirements. | ✅ | ✅ | ❌ |
|
||||
| `pruner` | Removes less important or redundant connections and neurons, resulting in a sparser, more efficient network. | ✅ | ✅ | ❌ |
|
||||
| `recoverer` | Restores the performance of a model after compression. | ➖ | ➖ | ✅ |
|
||||
| `factorizer` | Factorization batches several small matrix multiplications into one large fused operation. | ✅ | ➖ | ➖ |
|
||||
| `enhancer` | Enhances the model output by applying post-processing algorithms such as denoising or upscaling. | ❌ | - | ✅ |
|
||||
|
||||
✅ (improves), ➖ (approx. the same), ❌ (worsens)
|
||||
|
||||
Explore the full range of optimization methods in the [Pruna documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms).
|
||||
|
||||
## Installation
|
||||
|
||||
Install Pruna with the following command.
|
||||
|
||||
```bash
|
||||
pip install pruna
|
||||
```
|
||||
|
||||
|
||||
## Optimize Diffusers models
|
||||
|
||||
A broad range of optimization algorithms are supported for Diffusers models as shown below.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/diffusers_combinations.png" alt="Overview of the supported optimization algorithms for diffusers models">
|
||||
</div>
|
||||
|
||||
The example below optimizes [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
||||
with a combination of factorizer, compiler, and cacher algorithms. This combination accelerates inference by up to 4.2x and cuts peak GPU memory usage from 34.7GB to 28.0GB, all while maintaining virtually the same output quality.
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) docs to learn more about the optimization techniques used in this example.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_combination.png" alt="Optimization techniques used for FLUX.1-dev showing the combination of factorizer, compiler, and cacher algorithms">
|
||||
</div>
|
||||
|
||||
Start by defining a `SmashConfig` with the optimization algorithms to use. To optimize the model, wrap the pipeline and the `SmashConfig` with `smash` and then use the pipeline as normal for inference.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
from pruna import PrunaModel, SmashConfig, smash
|
||||
|
||||
# load the model
|
||||
# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell with a small GPU memory
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# define the configuration
|
||||
smash_config = SmashConfig()
|
||||
smash_config["factorizer"] = "qkv_diffusers"
|
||||
smash_config["compiler"] = "torch_compile"
|
||||
smash_config["torch_compile_target"] = "module_list"
|
||||
smash_config["cacher"] = "fora"
|
||||
smash_config["fora_interval"] = 2
|
||||
|
||||
# for the best results in terms of speed you can add these configs
|
||||
# however they will increase your warmup time from 1.5 min to 10 min
|
||||
# smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
|
||||
# smash_config["quantizer"] = "torchao"
|
||||
# smash_config["torchao_quant_type"] = "fp8dq"
|
||||
# smash_config["torchao_excluded_modules"] = "norm+embedding"
|
||||
|
||||
# optimize the model
|
||||
smashed_pipe = smash(pipe, smash_config)
|
||||
|
||||
# run the model
|
||||
smashed_pipe("a knitted purple prune").images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_smashed_comparison.png">
|
||||
</div>
|
||||
|
||||
After optimization, we can share and load the optimized model using the Hugging Face Hub.
|
||||
|
||||
```python
|
||||
# save the model
|
||||
smashed_pipe.save_to_hub("<username>/FLUX.1-dev-smashed")
|
||||
|
||||
# load the model
|
||||
smashed_pipe = PrunaModel.from_hub("<username>/FLUX.1-dev-smashed")
|
||||
```
|
||||
|
||||
## Evaluate and benchmark Diffusers models
|
||||
|
||||
Pruna provides the [EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) to evaluate the quality of your optimized models.
|
||||
|
||||
We can metrics we care about, such as total time and throughput, and the dataset to evaluate on. We can define a model and pass it to the `EvaluationAgent`.
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="optimized model">
|
||||
|
||||
We can load and evaluate an optimized model by using the `EvaluationAgent` and pass it to the `Task`.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
from pruna import PrunaModel
|
||||
from pruna.data.pruna_datamodule import PrunaDataModule
|
||||
from pruna.evaluation.evaluation_agent import EvaluationAgent
|
||||
from pruna.evaluation.metrics import (
|
||||
ThroughputMetric,
|
||||
TorchMetricWrapper,
|
||||
TotalTimeMetric,
|
||||
)
|
||||
from pruna.evaluation.task import Task
|
||||
|
||||
# define the device
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
# load the model
|
||||
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
|
||||
smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed")
|
||||
|
||||
# Define the metrics
|
||||
metrics = [
|
||||
TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),
|
||||
ThroughputMetric(n_iterations=20, n_warmup_iterations=5),
|
||||
TorchMetricWrapper("clip"),
|
||||
]
|
||||
|
||||
# Define the datamodule
|
||||
datamodule = PrunaDataModule.from_string("LAION256")
|
||||
datamodule.limit_datasets(10)
|
||||
|
||||
# Define the task and evaluation agent
|
||||
task = Task(metrics, datamodule=datamodule, device=device)
|
||||
eval_agent = EvaluationAgent(task)
|
||||
|
||||
# Evaluate smashed model and offload it to CPU
|
||||
smashed_pipe.move_to_device(device)
|
||||
smashed_pipe_results = eval_agent.evaluate(smashed_pipe)
|
||||
smashed_pipe.move_to_device("cpu")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="standalone model">
|
||||
|
||||
Instead of comparing the optimized model to the base model, you can also evaluate the standalone `diffusers` model. This is useful if you want to evaluate the performance of the model without the optimization. We can do so by using the `PrunaModel` wrapper and run the `EvaluationAgent` on it.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
from pruna import PrunaModel
|
||||
|
||||
# load the model
|
||||
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16
|
||||
).to("cpu")
|
||||
wrapped_pipe = PrunaModel(model=pipe)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Now that you have seen how to optimize and evaluate your models, you can start using Pruna to optimize your own models. Luckily, we have many examples to help you get started.
|
||||
|
||||
> [!TIP]
|
||||
> For more details about benchmarking Flux, check out the [Announcing FLUX-Juiced: The Fastest Image Generation Endpoint (2.6 times faster)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) blog post and the [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) Space.
|
||||
|
||||
## Reference
|
||||
|
||||
- [Pruna](https://github.com/pruna-ai/pruna)
|
||||
- [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)
|
||||
- [Pruna evaluation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)
|
||||
- [Pruna tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)
|
||||
|
||||
@@ -416,6 +416,45 @@ text_encoder_2_4bit.dequantize()
|
||||
transformer_4bit.dequantize()
|
||||
```
|
||||
|
||||
## torch.compile
|
||||
|
||||
Speed up inference with `torch.compile`. Make sure you have the latest `bitsandbytes` installed and we also recommend installing [PyTorch nightly](https://pytorch.org/get-started/locally/).
|
||||
|
||||
<hfoptions id="bnb">
|
||||
<hfoption id="8-bit">
|
||||
```py
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_4bit = AutoModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
transformer_4bit.compile(fullgraph=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="4-bit">
|
||||
|
||||
```py
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)
|
||||
transformer_4bit = AutoModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
transformer_4bit.compile(fullgraph=True)
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
On an RTX 4090 with compilation, 4-bit Flux generation completed in 25.809 seconds versus 32.570 seconds without.
|
||||
|
||||
Check out the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) for more details.
|
||||
|
||||
## Resources
|
||||
|
||||
* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
|
||||
|
||||
@@ -65,6 +65,9 @@ transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
|
||||
|
||||
For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
|
||||
|
||||
> [!TIP]
|
||||
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
|
||||
|
||||
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
|
||||
|
||||
The `TorchAoConfig` class accepts three parameters:
|
||||
|
||||
@@ -76,6 +76,24 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
|
||||
> `pip install wandb`
|
||||
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
|
||||
|
||||
### LoRA Rank and Alpha
|
||||
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
|
||||
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
|
||||
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
|
||||
- lora_alpha vs. rank:
|
||||
This ratio dictates the LoRA's effective strength:
|
||||
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
|
||||
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
|
||||
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
|
||||
|
||||
> [!TIP]
|
||||
> A common starting point is to set `lora_alpha` equal to `rank`.
|
||||
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
|
||||
> to give the LoRA updates more influence without increasing parameter count.
|
||||
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
|
||||
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
|
||||
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -20,6 +21,8 @@ import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
@@ -281,3 +284,45 @@ class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_with_metadata(self):
|
||||
# Use a `lora_alpha` that is different from `rank`.
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
|
||||
@@ -55,6 +55,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
@@ -431,6 +432,13 @@ def parse_args(input_args=None):
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=int,
|
||||
default=4,
|
||||
help="LoRA alpha to be used for additional scaling.",
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
@@ -1556,7 +1564,7 @@ def main(args):
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
@@ -1565,7 +1573,7 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
@@ -1582,13 +1590,15 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["text_encoder"] = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
|
||||
else:
|
||||
@@ -1601,6 +1611,7 @@ def main(args):
|
||||
output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
**_collate_lora_metadata(modules_to_save),
|
||||
)
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
|
||||
@@ -2359,16 +2370,19 @@ def main(args):
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
modules_to_save["text_encoder"] = text_encoder_one
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
|
||||
@@ -2377,6 +2391,7 @@ def main(args):
|
||||
save_directory=args.output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
**_collate_lora_metadata(modules_to_save),
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
|
||||
@@ -282,10 +282,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
# Copyright Philip Brown, ppbrown@github
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
###########################################################################
|
||||
# This pipeline attempts to use a model that has SDXL vae, T5 text encoder,
|
||||
# and SDXL unet.
|
||||
# At the present time, there are no pretrained models that give pleasing
|
||||
# output. So as yet, (2025/06/10) this pipeline is somewhat of a tech
|
||||
# demo proving that the pieces can at least be put together.
|
||||
# Hopefully, it will encourage someone with the hardware available to
|
||||
# throw enough resources into training one up.
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
|
||||
|
||||
# Note: At this time, the intent is to use the T5 encoder mentioned
|
||||
# below, with zero changes.
|
||||
# Therefore, the model deliberately does not store the T5 encoder model bytes,
|
||||
# (Since they are not unique!)
|
||||
# but instead takes advantage of huggingface hub cache loading
|
||||
|
||||
T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
|
||||
|
||||
# Caller is expected to load this, or equivalent, as model name for now
|
||||
# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME)
|
||||
SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
|
||||
class LinearWithDtype(nn.Linear):
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.weight.dtype
|
||||
|
||||
|
||||
class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
|
||||
_expected_modules = [
|
||||
"vae",
|
||||
"unet",
|
||||
"scheduler",
|
||||
"tokenizer",
|
||||
"image_encoder",
|
||||
"feature_extractor",
|
||||
"t5_encoder",
|
||||
"t5_projection",
|
||||
"t5_pooled_projection",
|
||||
]
|
||||
|
||||
_optional_components = [
|
||||
"image_encoder",
|
||||
"feature_extractor",
|
||||
"t5_encoder",
|
||||
"t5_projection",
|
||||
"t5_pooled_projection",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
tokenizer: CLIPTokenizer,
|
||||
t5_encoder=None,
|
||||
t5_projection=None,
|
||||
t5_pooled_projection=None,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
DiffusionPipeline.__init__(self)
|
||||
|
||||
if t5_encoder is None:
|
||||
self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
|
||||
else:
|
||||
self.t5_encoder = t5_encoder
|
||||
|
||||
# ----- build T5 4096 => 2048 dim projection -----
|
||||
if t5_projection is None:
|
||||
self.t5_projection = LinearWithDtype(4096, 2048) # trainable
|
||||
else:
|
||||
self.t5_projection = t5_projection
|
||||
self.t5_projection.to(dtype=unet.dtype)
|
||||
# ----- build T5 4096 => 1280 dim projection -----
|
||||
if t5_pooled_projection is None:
|
||||
self.t5_pooled_projection = LinearWithDtype(4096, 1280) # trainable
|
||||
else:
|
||||
self.t5_pooled_projection = t5_pooled_projection
|
||||
self.t5_pooled_projection.to(dtype=unet.dtype)
|
||||
|
||||
print("dtype of Linear is ", self.t5_projection.dtype)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
tokenizer=tokenizer,
|
||||
t5_encoder=self.t5_encoder,
|
||||
t5_projection=self.t5_projection,
|
||||
t5_pooled_projection=self.t5_pooled_projection,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
self.watermark = None
|
||||
|
||||
# Parts of original SDXL class complain if these attributes are not
|
||||
# at least PRESENT
|
||||
self.text_encoder = self.text_encoder_2 = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Encode a text prompt (T5-XXL + 4096→2048 projection)
|
||||
# Returns exactly four tensors in the order SDXL’s __call__ expects.
|
||||
# ------------------------------------------------------------------
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str | None = None,
|
||||
**_,
|
||||
):
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
prompt_embeds : Tensor [B, T, 2048]
|
||||
negative_prompt_embeds : Tensor [B, T, 2048] | None
|
||||
pooled_prompt_embeds : Tensor [B, 1280]
|
||||
negative_pooled_prompt_embeds: Tensor [B, 1280] | None
|
||||
where B = batch * num_images_per_prompt
|
||||
"""
|
||||
|
||||
# --- helper to tokenize on the pipeline’s device ----------------
|
||||
def _tok(text: str):
|
||||
tok_out = self.tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
).to(self.device)
|
||||
return tok_out.input_ids, tok_out.attention_mask
|
||||
|
||||
# ---------- positive stream -------------------------------------
|
||||
ids, mask = _tok(prompt)
|
||||
h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state # [b, T, 4096]
|
||||
tok_pos = self.t5_projection(h_pos) # [b, T, 2048]
|
||||
pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1)) # [b, 1280]
|
||||
|
||||
# expand for multiple images per prompt
|
||||
tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
|
||||
pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)
|
||||
|
||||
# ---------- negative / CFG stream --------------------------------
|
||||
if do_classifier_free_guidance:
|
||||
neg_text = "" if negative_prompt is None else negative_prompt
|
||||
ids_n, mask_n = _tok(neg_text)
|
||||
h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
|
||||
tok_neg = self.t5_projection(h_neg)
|
||||
pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))
|
||||
|
||||
tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
|
||||
pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
|
||||
else:
|
||||
tok_neg = pool_neg = None
|
||||
|
||||
# ----------------- final ordered return --------------------------
|
||||
# 1) positive token embeddings
|
||||
# 2) negative token embeddings (or None)
|
||||
# 3) positive pooled embeddings
|
||||
# 4) negative pooled embeddings (or None)
|
||||
return tok_pos, tok_neg, pool_pos, pool_neg
|
||||
@@ -170,6 +170,23 @@ accelerate launch train_dreambooth_lora_flux.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### LoRA Rank and Alpha
|
||||
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
|
||||
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
|
||||
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
|
||||
- lora_alpha vs. rank:
|
||||
This ratio dictates the LoRA's effective strength:
|
||||
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
|
||||
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
|
||||
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
|
||||
|
||||
> [!TIP]
|
||||
> A common starting point is to set `lora_alpha` equal to `rank`.
|
||||
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
|
||||
> to give the LoRA updates more influence without increasing parameter count.
|
||||
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
|
||||
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -20,6 +21,8 @@ import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
@@ -234,3 +237,45 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_with_metadata(self):
|
||||
# Use a `lora_alpha` that is different from `rank`.
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
|
||||
@@ -27,7 +27,6 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
@@ -53,6 +52,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
@@ -358,7 +358,12 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=int,
|
||||
default=4,
|
||||
help="LoRA alpha to be used for additional scaling.",
|
||||
)
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
@@ -1238,7 +1243,7 @@ def main(args):
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
@@ -1247,7 +1252,7 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
@@ -1264,12 +1269,14 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["text_encoder"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1280,6 +1287,7 @@ def main(args):
|
||||
output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
**_collate_lora_metadata(modules_to_save),
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
@@ -1889,16 +1897,19 @@ def main(args):
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
modules_to_save["text_encoder"] = text_encoder_one
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
|
||||
@@ -1906,6 +1917,7 @@ def main(args):
|
||||
save_directory=args.output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
**_collate_lora_metadata(modules_to_save),
|
||||
)
|
||||
|
||||
# Final inference
|
||||
|
||||
@@ -29,7 +29,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
@@ -1181,13 +1181,15 @@ def main(args):
|
||||
transformer_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
model = unwrap_model(model)
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
HiDreamImagePipeline.save_lora_weights(
|
||||
output_dir,
|
||||
@@ -1197,13 +1199,20 @@ def main(args):
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
model = unwrap_model(model)
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
else:
|
||||
transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="transformer"
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
|
||||
lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -1655,7 +1664,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
|
||||
@@ -7,7 +7,17 @@ from accelerate import init_empty_weights
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
|
||||
from diffusers import (
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLWan,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
CosmosTransformer3DModel,
|
||||
CosmosVideoToWorldPipeline,
|
||||
EDMEulerScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
@@ -29,7 +39,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"affline_norm": "time_embed.norm",
|
||||
".blocks.0.block.attn": ".attn1",
|
||||
@@ -56,7 +66,7 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
|
||||
"blocks.block": rename_transformer_blocks_,
|
||||
"logvar.0.freqs": remove_keys_,
|
||||
"logvar.0.phases": remove_keys_,
|
||||
@@ -64,6 +74,45 @@ TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
}
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"t_embedding_norm": "time_embed.norm",
|
||||
"blocks": "transformer_blocks",
|
||||
"adaln_modulation_self_attn.1": "norm1.linear_1",
|
||||
"adaln_modulation_self_attn.2": "norm1.linear_2",
|
||||
"adaln_modulation_cross_attn.1": "norm2.linear_1",
|
||||
"adaln_modulation_cross_attn.2": "norm2.linear_2",
|
||||
"adaln_modulation_mlp.1": "norm3.linear_1",
|
||||
"adaln_modulation_mlp.2": "norm3.linear_2",
|
||||
"self_attn": "attn1",
|
||||
"cross_attn": "attn2",
|
||||
"q_proj": "to_q",
|
||||
"k_proj": "to_k",
|
||||
"v_proj": "to_v",
|
||||
"output_proj": "to_out.0",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
"mlp.layer1": "ff.net.0.proj",
|
||||
"mlp.layer2": "ff.net.2",
|
||||
"x_embedder.proj.1": "patch_embed.proj",
|
||||
# "extra_pos_embedder": "learnable_pos_embed",
|
||||
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
|
||||
"accum_video_sample_counter": remove_keys_,
|
||||
"accum_image_sample_counter": remove_keys_,
|
||||
"accum_iteration": remove_keys_,
|
||||
"accum_train_in_hours": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
"pos_embedder.dim_spatial_range": remove_keys_,
|
||||
"pos_embedder.dim_temporal_range": remove_keys_,
|
||||
"_extra_state": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
TRANSFORMER_CONFIGS = {
|
||||
"Cosmos-1.0-Diffusion-7B-Text2World": {
|
||||
"in_channels": 16,
|
||||
@@ -125,6 +174,66 @@ TRANSFORMER_CONFIGS = {
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
"Cosmos-2.0-Diffusion-2B-Text2Image": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (1.0, 4.0, 4.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": None,
|
||||
},
|
||||
"Cosmos-2.0-Diffusion-14B-Text2Image": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 36,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (1.0, 4.0, 4.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": None,
|
||||
},
|
||||
"Cosmos-2.0-Diffusion-2B-Video2World": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (1.0, 3.0, 3.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": None,
|
||||
},
|
||||
"Cosmos-2.0-Diffusion-14B-Video2World": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 36,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (20 / 24, 2.0, 2.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": None,
|
||||
},
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
@@ -216,9 +325,18 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_transformer(transformer_type: str, ckpt_path: str):
|
||||
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
|
||||
PREFIX_KEY = "net."
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
|
||||
|
||||
if "Cosmos-1.0" in transformer_type:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
|
||||
elif "Cosmos-2.0" in transformer_type:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||
else:
|
||||
assert False
|
||||
|
||||
with init_empty_weights():
|
||||
config = TRANSFORMER_CONFIGS[transformer_type]
|
||||
@@ -281,13 +399,61 @@ def convert_vae(vae_type: str):
|
||||
return vae
|
||||
|
||||
|
||||
def save_pipeline_cosmos_1_0(args, transformer, vae):
|
||||
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
|
||||
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
|
||||
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
|
||||
# So, the sigma_min values that is used is the default value of 0.002.
|
||||
scheduler = EDMEulerScheduler(
|
||||
sigma_min=0.002,
|
||||
sigma_max=80,
|
||||
sigma_data=0.5,
|
||||
sigma_schedule="karras",
|
||||
num_train_timesteps=1000,
|
||||
prediction_type="epsilon",
|
||||
rho=7.0,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
|
||||
pipe_cls = CosmosTextToWorldPipeline if "Text2World" in args.transformer_type else CosmosVideoToWorldPipeline
|
||||
pipe = pipe_cls(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
safety_checker=lambda *args, **kwargs: None,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def save_pipeline_cosmos_2_0(args, transformer, vae):
|
||||
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
|
||||
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
|
||||
|
||||
pipe_cls = Cosmos2TextToImagePipeline if "Text2Image" in args.transformer_type else Cosmos2VideoToWorldPipeline
|
||||
pipe = pipe_cls(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
safety_checker=lambda *args, **kwargs: None,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
|
||||
parser.add_argument(
|
||||
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
|
||||
)
|
||||
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
@@ -316,37 +482,26 @@ if __name__ == "__main__":
|
||||
assert args.tokenizer_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
|
||||
weights_only = "Cosmos-1.0" in args.transformer_type
|
||||
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.vae_type is not None:
|
||||
vae = convert_vae(args.vae_type)
|
||||
if "Cosmos-1.0" in args.transformer_type:
|
||||
vae = convert_vae(args.vae_type)
|
||||
else:
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
|
||||
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
|
||||
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
|
||||
# So, the sigma_min values that is used is the default value of 0.002.
|
||||
scheduler = EDMEulerScheduler(
|
||||
sigma_min=0.002,
|
||||
sigma_max=80,
|
||||
sigma_data=0.5,
|
||||
sigma_schedule="karras",
|
||||
num_train_timesteps=1000,
|
||||
prediction_type="epsilon",
|
||||
rho=7.0,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
|
||||
pipe = CosmosTextToWorldPipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
if "Cosmos-1.0" in args.transformer_type:
|
||||
save_pipeline_cosmos_1_0(args, transformer, vae)
|
||||
elif "Cosmos-2.0" in args.transformer_type:
|
||||
save_pipeline_cosmos_2_0(args, transformer, vae)
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
|
||||
# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
@@ -14,6 +14,8 @@ from diffusers import (
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanTransformer3DModel,
|
||||
WanVACEPipeline,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
|
||||
|
||||
@@ -59,7 +61,52 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"attn2.norm_k_img": "attn2.norm_added_k",
|
||||
}
|
||||
|
||||
VACE_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
||||
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
||||
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
||||
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
||||
"time_projection.1": "condition_embedder.time_proj",
|
||||
"head.modulation": "scale_shift_table",
|
||||
"head.head": "proj_out",
|
||||
"modulation": "scale_shift_table",
|
||||
"ffn.0": "ffn.net.0.proj",
|
||||
"ffn.2": "ffn.net.2",
|
||||
# Hack to swap the layer names
|
||||
# The original model calls the norms in following order: norm1, norm3, norm2
|
||||
# We convert it to: norm1, norm2, norm3
|
||||
"norm2": "norm__placeholder",
|
||||
"norm3": "norm2",
|
||||
"norm__placeholder": "norm3",
|
||||
# # For the I2V model
|
||||
# "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
||||
# "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
# "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
# "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
# # for the FLF2V model
|
||||
# "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
|
||||
# Add attention component mappings
|
||||
"self_attn.q": "attn1.to_q",
|
||||
"self_attn.k": "attn1.to_k",
|
||||
"self_attn.v": "attn1.to_v",
|
||||
"self_attn.o": "attn1.to_out.0",
|
||||
"self_attn.norm_q": "attn1.norm_q",
|
||||
"self_attn.norm_k": "attn1.norm_k",
|
||||
"cross_attn.q": "attn2.to_q",
|
||||
"cross_attn.k": "attn2.to_k",
|
||||
"cross_attn.v": "attn2.to_v",
|
||||
"cross_attn.o": "attn2.to_out.0",
|
||||
"cross_attn.norm_q": "attn2.norm_q",
|
||||
"cross_attn.norm_k": "attn2.norm_k",
|
||||
"attn2.to_k_img": "attn2.add_k_proj",
|
||||
"attn2.to_v_img": "attn2.add_v_proj",
|
||||
"attn2.norm_k_img": "attn2.norm_added_k",
|
||||
"before_proj": "proj_in",
|
||||
"after_proj": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
@@ -74,7 +121,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
if model_type == "Wan-T2V-1.3B":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
|
||||
@@ -94,6 +141,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-T2V-14B":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
|
||||
@@ -113,6 +162,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-I2V-14B-480p":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
|
||||
@@ -133,6 +184,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
|
||||
@@ -153,6 +206,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-FLF2V-14B-720P":
|
||||
config = {
|
||||
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
|
||||
@@ -175,11 +230,60 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"pos_embed_seq_len": 257 * 2,
|
||||
},
|
||||
}
|
||||
return config
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-VACE-1.3B":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.1-VACE-1.3B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 12,
|
||||
"num_layers": 30,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
|
||||
"vace_in_channels": 96,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-VACE-14B":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.1-VACE-14B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
"vace_in_channels": 96,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
|
||||
|
||||
|
||||
def convert_transformer(model_type: str):
|
||||
config = get_transformer_config(model_type)
|
||||
config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
|
||||
|
||||
diffusers_config = config["diffusers_config"]
|
||||
model_id = config["model_id"]
|
||||
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
|
||||
@@ -187,16 +291,19 @@ def convert_transformer(model_type: str):
|
||||
original_state_dict = load_sharded_safetensors(model_dir)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
if "VACE" not in model_type:
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
else:
|
||||
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
for replace_key, rename_key in RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
@@ -412,7 +519,7 @@ def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--dtype", default="fp32")
|
||||
parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -426,18 +533,20 @@ DTYPE_MAPPING = {
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
transformer = convert_transformer(args.model_type).to(dtype=dtype)
|
||||
transformer = convert_transformer(args.model_type)
|
||||
vae = convert_vae()
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
|
||||
)
|
||||
|
||||
# If user has specified "none", we keep the original dtypes of the state dict without any conversion
|
||||
if args.dtype != "none":
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
transformer.to(dtype)
|
||||
|
||||
if "I2V" in args.model_type or "FLF2V" in args.model_type:
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
|
||||
@@ -452,6 +561,14 @@ if __name__ == "__main__":
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
elif "VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
pipe = WanPipeline(
|
||||
transformer=transformer,
|
||||
|
||||
@@ -159,6 +159,7 @@ else:
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"CacheMixin",
|
||||
"ChromaTransformer2DModel",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"CogView4Transformer2DModel",
|
||||
@@ -215,6 +216,7 @@ else:
|
||||
"UVit2DModel",
|
||||
"VQModel",
|
||||
"WanTransformer3DModel",
|
||||
"WanVACETransformer3DModel",
|
||||
]
|
||||
)
|
||||
_import_structure["optimization"] = [
|
||||
@@ -351,6 +353,7 @@ else:
|
||||
"AuraFlowPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"ChromaPipeline",
|
||||
"CLIPImageProjection",
|
||||
"CogVideoXFunControlPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
@@ -360,6 +363,8 @@ else:
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"Cosmos2TextToImagePipeline",
|
||||
"Cosmos2VideoToWorldPipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
"CosmosVideoToWorldPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
@@ -527,6 +532,7 @@ else:
|
||||
"VQDiffusionPipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
"WanVACEPipeline",
|
||||
"WanVideoToVideoPipeline",
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
@@ -766,6 +772,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
CacheMixin,
|
||||
ChromaTransformer2DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
@@ -821,6 +828,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UVit2DModel,
|
||||
VQModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
@@ -937,6 +945,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
ChromaPipeline,
|
||||
CLIPImageProjection,
|
||||
CogVideoXFunControlPipeline,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
@@ -946,6 +955,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
CosmosVideoToWorldPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
@@ -1112,6 +1123,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQDiffusionPipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanVACEPipeline,
|
||||
WanVideoToVideoPipeline,
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
|
||||
@@ -159,10 +159,7 @@ class IPAdapterMixin:
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
@@ -465,10 +462,7 @@ class FluxIPAdapterMixin:
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
@@ -750,10 +744,7 @@ class SD3IPAdapterMixin:
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
@@ -45,6 +46,7 @@ from ..utils import (
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from ..utils.state_dict_utils import _load_sft_state_dict_metadata
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -62,6 +64,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
|
||||
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
@@ -206,6 +209,7 @@ def _fetch_state_dict(
|
||||
subfolder,
|
||||
user_agent,
|
||||
allow_pickle,
|
||||
metadata=None,
|
||||
):
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
@@ -236,11 +240,14 @@ def _fetch_state_dict(
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
metadata = _load_sft_state_dict_metadata(model_file)
|
||||
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
metadata = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
@@ -261,10 +268,11 @@ def _fetch_state_dict(
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
metadata = None
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
return state_dict, metadata
|
||||
|
||||
|
||||
def _best_guess_weight_name(
|
||||
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
|
||||
return weight_name
|
||||
|
||||
|
||||
def _pack_dict_with_prefix(state_dict, prefix):
|
||||
sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
|
||||
return sd_with_prefix
|
||||
|
||||
|
||||
def _load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if network_alphas and metadata:
|
||||
raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
if metadata is not None:
|
||||
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
if metadata is not None:
|
||||
lora_config_kwargs = metadata
|
||||
else:
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
try:
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
except TypeError as e:
|
||||
raise TypeError("`LoraConfig` class could not be instantiated.") from e
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
@@ -889,8 +914,7 @@ class LoraBaseMixin:
|
||||
@staticmethod
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
return _pack_dict_with_prefix(layers_weights, prefix)
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
@@ -900,16 +924,32 @@ class LoraBaseMixin:
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
lora_adapter_metadata: Optional[dict] = None,
|
||||
):
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
if lora_adapter_metadata and not safe_serialization:
|
||||
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
|
||||
if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
|
||||
raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
|
||||
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
# Inject framework format.
|
||||
metadata = {"format": "pt"}
|
||||
if lora_adapter_metadata:
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
|
||||
lora_adapter_metadata, indent=2, sort_keys=True
|
||||
)
|
||||
|
||||
return safetensors.torch.save_file(weights, filename, metadata=metadata)
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
@@ -1596,7 +1596,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
|
||||
block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
|
||||
min_block = min(block_numbers)
|
||||
max_block = max(block_numbers)
|
||||
|
||||
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
||||
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
|
||||
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
|
||||
@@ -1605,76 +1608,105 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
if diff_keys:
|
||||
for diff_k in diff_keys:
|
||||
param = original_state_dict[diff_k]
|
||||
# The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
|
||||
# and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
|
||||
# to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
|
||||
# is okay to ignore because they do not affect the model output in a significant manner.
|
||||
threshold = 1.6e-2
|
||||
absdiff = param.abs().max() - param.abs().min()
|
||||
all_zero = torch.all(param == 0).item()
|
||||
if all_zero:
|
||||
logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
|
||||
all_absdiff_lower_than_threshold = absdiff < threshold
|
||||
if all_zero or all_absdiff_lower_than_threshold:
|
||||
logger.debug(
|
||||
f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
|
||||
)
|
||||
original_state_dict.pop(diff_k)
|
||||
|
||||
# For the `diff_b` keys, we treat them as lora_bias.
|
||||
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
|
||||
|
||||
for i in range(num_blocks):
|
||||
for i in range(min_block, max_block + 1):
|
||||
# Self-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
|
||||
)
|
||||
if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.diff_b"
|
||||
)
|
||||
original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"blocks.{i}.self_attn.{o}.diff_b"
|
||||
converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
# Cross-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
|
||||
)
|
||||
if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.diff_b"
|
||||
)
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
if is_i2v_lora:
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
|
||||
)
|
||||
if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.diff_b"
|
||||
)
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
# FFN
|
||||
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.{lora_down_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.{lora_up_key}.weight"
|
||||
)
|
||||
if f"blocks.{i}.{o}.diff_b" in original_state_dict:
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.diff_b"
|
||||
)
|
||||
original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
|
||||
converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
|
||||
converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"blocks.{i}.{o}.diff_b"
|
||||
converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
# Remaining.
|
||||
if original_state_dict:
|
||||
if any("time_projection" in k for k in original_state_dict):
|
||||
converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
|
||||
f"time_projection.1.{lora_down_key}.weight"
|
||||
)
|
||||
converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
|
||||
f"time_projection.1.{lora_up_key}.weight"
|
||||
)
|
||||
original_key = f"time_projection.1.{lora_down_key}.weight"
|
||||
converted_key = "condition_embedder.time_proj.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"time_projection.1.{lora_up_key}.weight"
|
||||
converted_key = "condition_embedder.time_proj.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
if "time_projection.1.diff_b" in original_state_dict:
|
||||
converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
|
||||
"time_projection.1.diff_b"
|
||||
@@ -1709,6 +1741,20 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
|
||||
)
|
||||
|
||||
for img_ours, img_theirs in [
|
||||
("ff.net.0.proj", "img_emb.proj.1"),
|
||||
("ff.net.2", "img_emb.proj.3"),
|
||||
]:
|
||||
original_key = f"{img_theirs}.{lora_down_key}.weight"
|
||||
converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
original_key = f"{img_theirs}.{lora_up_key}.weight"
|
||||
converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
diff = all(".diff" in k for k in original_state_dict)
|
||||
if diff:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@@ -58,6 +59,8 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
@@ -184,6 +187,9 @@ class PeftAdapterMixin:
|
||||
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
|
||||
limitations to this technique, which are documented here:
|
||||
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
|
||||
metadata:
|
||||
LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
|
||||
initialize `LoraConfig`.
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
@@ -201,6 +207,7 @@ class PeftAdapterMixin:
|
||||
network_alphas = kwargs.pop("network_alphas", None)
|
||||
_pipeline = kwargs.pop("_pipeline", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
|
||||
metadata = kwargs.pop("metadata", None)
|
||||
allow_pickle = False
|
||||
|
||||
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
|
||||
@@ -208,12 +215,9 @@ class PeftAdapterMixin:
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
@@ -226,12 +230,17 @@ class PeftAdapterMixin:
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
metadata=metadata,
|
||||
)
|
||||
if network_alphas is not None and prefix is None:
|
||||
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
|
||||
if network_alphas and metadata:
|
||||
raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
|
||||
|
||||
if prefix is not None:
|
||||
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
if metadata is not None:
|
||||
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
|
||||
@@ -266,7 +275,12 @@ class PeftAdapterMixin:
|
||||
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
if metadata is not None:
|
||||
lora_config_kwargs = metadata
|
||||
else:
|
||||
lora_config_kwargs = get_peft_kwargs(
|
||||
rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
|
||||
)
|
||||
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
@@ -289,7 +303,11 @@ class PeftAdapterMixin:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
try:
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
except TypeError as e:
|
||||
raise TypeError("`LoraConfig` class could not be instantiated.") from e
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(self)
|
||||
@@ -444,17 +462,13 @@ class PeftAdapterMixin:
|
||||
underlying model has multiple adapters loaded.
|
||||
upcast_before_saving (`bool`, defaults to `False`):
|
||||
Whether to cast the underlying model to `torch.float32` before serialization.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
|
||||
"""
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(self)
|
||||
@@ -462,6 +476,8 @@ class PeftAdapterMixin:
|
||||
if adapter_name not in getattr(self, "peft_config", {}):
|
||||
raise ValueError(f"Adapter name {adapter_name} not found in the model.")
|
||||
|
||||
lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
|
||||
|
||||
lora_layers_to_save = get_peft_model_state_dict(
|
||||
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
|
||||
)
|
||||
@@ -471,7 +487,15 @@ class PeftAdapterMixin:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
# Inject framework format.
|
||||
metadata = {"format": "pt"}
|
||||
if lora_adapter_metadata is not None:
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
|
||||
return safetensors.torch.save_file(weights, filename, metadata=metadata)
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
@@ -484,7 +508,6 @@ class PeftAdapterMixin:
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
# TODO: we could consider saving the `peft_config` as well.
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(lora_layers_to_save, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@@ -29,6 +29,7 @@ from .single_file_utils import (
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_chroma_transformer_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hidream_transformer_to_diffusers,
|
||||
@@ -97,6 +98,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"ChromaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"LTXVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
|
||||
@@ -3310,3 +3310,172 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
||||
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
||||
num_guidance_layers = (
|
||||
list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1 # noqa: C401
|
||||
)
|
||||
mlp_ratio = 4.0
|
||||
inner_dim = 3072
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
# guidance
|
||||
converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
|
||||
"distilled_guidance_layer.in_proj.bias"
|
||||
)
|
||||
converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
|
||||
"distilled_guidance_layer.in_proj.weight"
|
||||
)
|
||||
converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
|
||||
"distilled_guidance_layer.out_proj.bias"
|
||||
)
|
||||
converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
|
||||
"distilled_guidance_layer.out_proj.weight"
|
||||
)
|
||||
for i in range(num_guidance_layers):
|
||||
block_prefix = f"distilled_guidance_layer.layers.{i}."
|
||||
converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
|
||||
f"distilled_guidance_layer.layers.{i}.in_layer.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
|
||||
f"distilled_guidance_layer.layers.{i}.in_layer.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
|
||||
f"distilled_guidance_layer.layers.{i}.out_layer.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
|
||||
f"distilled_guidance_layer.layers.{i}.out_layer.weight"
|
||||
)
|
||||
converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
|
||||
f"distilled_guidance_layer.norms.{i}.scale"
|
||||
)
|
||||
|
||||
# context_embedder
|
||||
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
|
||||
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
|
||||
|
||||
# x_embedder
|
||||
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
|
||||
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
# Q, K, V
|
||||
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
|
||||
)
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
|
||||
# qk_norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
|
||||
)
|
||||
# ff img_mlp
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mlp.0.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.bias"
|
||||
)
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.bias"
|
||||
)
|
||||
|
||||
# single transformer blocks
|
||||
for i in range(num_single_layers):
|
||||
block_prefix = f"single_transformer_blocks.{i}."
|
||||
# Q, K, V, mlp
|
||||
mlp_hidden_dim = int(inner_dim * mlp_ratio)
|
||||
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
|
||||
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(
|
||||
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
|
||||
# qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.norm.key_norm.scale"
|
||||
)
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
|
||||
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
|
||||
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -155,10 +155,7 @@ class UNet2DConditionLoadersMixin:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
|
||||
@@ -74,6 +74,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
@@ -89,6 +90,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
@@ -150,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .transformers import (
|
||||
AllegroTransformer3DModel,
|
||||
AuraFlowTransformer2DModel,
|
||||
ChromaTransformer2DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
@@ -178,6 +181,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
from .unets import (
|
||||
I2VGenXLUNet,
|
||||
|
||||
@@ -749,6 +749,16 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_sample_stride_height = 192
|
||||
self.tile_sample_stride_width = 192
|
||||
|
||||
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
|
||||
self._cached_conv_counts = {
|
||||
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
|
||||
if self.decoder is not None
|
||||
else 0,
|
||||
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
|
||||
if self.encoder is not None
|
||||
else 0,
|
||||
}
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
@@ -801,18 +811,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.use_slicing = False
|
||||
|
||||
def clear_cache(self):
|
||||
def _count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, WanCausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
self._conv_num = _count_conv3d(self.decoder)
|
||||
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
|
||||
self._conv_num = self._cached_conv_counts["decoder"]
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = _count_conv3d(self.encoder)
|
||||
self._enc_conv_num = self._cached_conv_counts["encoder"]
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ def get_timestep_embedding(
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
@@ -1149,9 +1149,7 @@ def get_1d_rotary_pos_embed(
|
||||
|
||||
theta = theta * ntk_factor
|
||||
freqs = (
|
||||
1.0
|
||||
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
||||
/ linear_factor
|
||||
1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
|
||||
) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||
is_npu = freqs.device.type == "npu"
|
||||
@@ -1327,7 +1325,7 @@ class Timesteps(nn.Module):
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
|
||||
@@ -814,14 +814,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoModel
|
||||
>>> import torch
|
||||
|
||||
>>> # This works.
|
||||
>>> model = AutoModel.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
|
||||
... )
|
||||
>>> # This also works (integer accelerator device ID).
|
||||
>>> model = AutoModel.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
|
||||
... )
|
||||
>>> # Specifying a supported offloading strategy like "auto" also works.
|
||||
>>> model = AutoModel.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
|
||||
... )
|
||||
>>> # Specifying a dictionary as `device_map` also works.
|
||||
>>> model = AutoModel.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
... subfolder="unet",
|
||||
... device_map={"": torch.device("cuda")},
|
||||
... )
|
||||
```
|
||||
|
||||
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
|
||||
can also refer to the [Diffusers-specific
|
||||
documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
|
||||
for more concrete examples.
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
||||
each GPU and the available CPU RAM if unset.
|
||||
@@ -1387,7 +1416,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
low_cpu_mem_usage: bool = True,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
keep_in_fp32_modules: Optional[List[str]] = None,
|
||||
device_map: Dict[str, Union[int, str, torch.device]] = None,
|
||||
device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
|
||||
offload_state_dict: Optional[bool] = None,
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
|
||||
@@ -17,6 +17,7 @@ if is_torch_available():
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_chroma import ChromaTransformer2DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_cosmos import CosmosTransformer3DModel
|
||||
@@ -32,3 +33,4 @@ if is_torch_available():
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
from .transformer_wan_vace import WanVACETransformer3DModel
|
||||
|
||||
@@ -0,0 +1,732 @@
|
||||
# Copyright 2025 Black Forest Labs, The HuggingFace Team and loadstone-rock . All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0_NPU,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
)
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChromaAdaLayerNormZeroPruned(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
|
||||
super().__init__()
|
||||
if num_embeddings is not None:
|
||||
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.emb is not None:
|
||||
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.flatten(1, 2).chunk(6, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
|
||||
super().__init__()
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
shift_msa, scale_msa, gate_msa = emb.flatten(1, 2).chunk(3, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa
|
||||
|
||||
|
||||
class ChromaAdaLayerNormContinuousPruned(nn.Module):
|
||||
r"""
|
||||
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
|
||||
|
||||
Args:
|
||||
embedding_dim (`int`): Embedding dimension to use during projection.
|
||||
conditioning_embedding_dim (`int`): Dimension of the input condition.
|
||||
elementwise_affine (`bool`, defaults to `True`):
|
||||
Boolean flag to denote if affine transformation should be applied.
|
||||
eps (`float`, defaults to 1e-5): Epsilon factor.
|
||||
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
|
||||
norm_type (`str`, defaults to `"layer_norm"`):
|
||||
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
||||
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
||||
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
||||
# However, this is how it was implemented in the original code, and it's rather likely you should
|
||||
# set `elementwise_affine` to False.
|
||||
elementwise_affine=True,
|
||||
eps=1e-5,
|
||||
bias=True,
|
||||
norm_type="layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
||||
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
||||
shift, scale = torch.chunk(emb.flatten(1, 2).to(x.dtype), 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
|
||||
def __init__(self, num_channels: int, out_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
|
||||
self.register_buffer(
|
||||
"mod_proj",
|
||||
get_timestep_embedding(
|
||||
torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor) -> torch.Tensor:
|
||||
mod_index_length = self.mod_proj.shape[0]
|
||||
batch_size = timestep.shape[0]
|
||||
|
||||
timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
|
||||
guidance_proj = self.guidance_proj(torch.tensor([0] * batch_size)).to(
|
||||
dtype=timestep.dtype, device=timestep.device
|
||||
)
|
||||
|
||||
mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device).repeat(batch_size, 1, 1)
|
||||
timestep_guidance = (
|
||||
torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
|
||||
)
|
||||
input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
|
||||
return input_vec.to(timestep.dtype)
|
||||
|
||||
|
||||
class ChromaApproximator(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
|
||||
super().__init__()
|
||||
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.layers = nn.ModuleList(
|
||||
[PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
|
||||
)
|
||||
self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
|
||||
self.out_proj = nn.Linear(hidden_dim, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.in_proj(x)
|
||||
|
||||
for layer, norms in zip(self.layers, self.norms):
|
||||
x = x + layer(norms(x))
|
||||
|
||||
return self.out_proj(x)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class ChromaSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
|
||||
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
if is_torch_npu_available():
|
||||
deprecation_message = (
|
||||
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
|
||||
"should be set explicitly using the `set_attn_processor` method."
|
||||
)
|
||||
deprecate("npu_processor", "0.34.0", deprecation_message)
|
||||
processor = FluxAttnProcessor2_0_NPU()
|
||||
else:
|
||||
processor = FluxAttnProcessor2_0()
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm="rms_norm",
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
gate = gate.unsqueeze(1)
|
||||
hidden_states = gate * self.proj_out(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class ChromaTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
qk_norm: str = "rms_norm",
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
|
||||
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=FluxAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
|
||||
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb_txt
|
||||
)
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
# Attention.
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
if len(attention_outputs) == 2:
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
elif len(attention_outputs) == 3:
|
||||
attn_output, context_attn_output, ip_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
if len(attention_outputs) == 3:
|
||||
hidden_states = hidden_states + ip_attn_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
|
||||
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class ChromaTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux, modified for Chroma.
|
||||
|
||||
Reference: https://huggingface.co/lodestones/Chroma
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `1`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `64`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `19`):
|
||||
The number of layers of dual stream DiT blocks to use.
|
||||
num_single_layers (`int`, defaults to `38`):
|
||||
The number of layers of single stream DiT blocks to use.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of dimensions to use for each attention head.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of attention heads to use.
|
||||
joint_attention_dim (`int`, defaults to `4096`):
|
||||
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
||||
`encoder_hidden_states`).
|
||||
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions to use for the rotary positional embeddings.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
|
||||
approximator_num_channels: int = 64,
|
||||
approximator_hidden_dim: int = 5120,
|
||||
approximator_layers: int = 5,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
||||
|
||||
self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
|
||||
num_channels=approximator_num_channels // 4,
|
||||
out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
|
||||
)
|
||||
self.distilled_guidance_layer = ChromaApproximator(
|
||||
in_dim=approximator_num_channels,
|
||||
out_dim=self.inner_dim,
|
||||
hidden_dim=approximator_hidden_dim,
|
||||
n_layers=approximator_layers,
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
ChromaTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
ChromaSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = ChromaAdaLayerNormContinuousPruned(
|
||||
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
|
||||
)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
return_dict: bool = True,
|
||||
controlnet_blocks_repeat: bool = False,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
|
||||
input_vec = self.time_text_embed(timestep)
|
||||
pooled_temb = self.distilled_guidance_layer(input_vec)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
||||
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
img_offset = 3 * len(self.single_transformer_blocks)
|
||||
txt_offset = img_offset + 6 * len(self.transformer_blocks)
|
||||
img_modulation = img_offset + 6 * index_block
|
||||
text_modulation = txt_offset + 6 * index_block
|
||||
temb = torch.cat(
|
||||
(
|
||||
pooled_temb[:, img_modulation : img_modulation + 6],
|
||||
pooled_temb[:, text_modulation : text_modulation + 6],
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
# For Xlabs ControlNet.
|
||||
if controlnet_blocks_repeat:
|
||||
hidden_states = (
|
||||
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
start_idx = 3 * index_block
|
||||
temb = pooled_temb[:, start_idx : start_idx + 3]
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
temb = pooled_temb[:, -2:]
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -100,11 +100,15 @@ class CosmosAdaLayerNorm(nn.Module):
|
||||
embedded_timestep = self.linear_2(embedded_timestep)
|
||||
|
||||
if temb is not None:
|
||||
embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
|
||||
embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim]
|
||||
|
||||
shift, scale = embedded_timestep.chunk(2, dim=1)
|
||||
shift, scale = embedded_timestep.chunk(2, dim=-1)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
if embedded_timestep.ndim == 2:
|
||||
shift, scale = (x.unsqueeze(1) for x in (shift, scale))
|
||||
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -135,9 +139,13 @@ class CosmosAdaLayerNormZero(nn.Module):
|
||||
if temb is not None:
|
||||
embedded_timestep = embedded_timestep + temb
|
||||
|
||||
shift, scale, gate = embedded_timestep.chunk(3, dim=1)
|
||||
shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
if embedded_timestep.ndim == 2:
|
||||
shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
|
||||
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
return hidden_states, gate
|
||||
|
||||
|
||||
@@ -255,19 +263,19 @@ class CosmosTransformerBlock(nn.Module):
|
||||
# 1. Self Attention
|
||||
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
|
||||
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + gate * attn_output
|
||||
|
||||
# 2. Cross Attention
|
||||
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + gate * attn_output
|
||||
|
||||
# 3. Feed Forward
|
||||
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
|
||||
hidden_states = hidden_states + gate * ff_output
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -513,7 +521,23 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
|
||||
|
||||
# 4. Timestep embeddings
|
||||
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
|
||||
if timestep.ndim == 1:
|
||||
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
|
||||
elif timestep.ndim == 5:
|
||||
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
|
||||
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
|
||||
)
|
||||
timestep = timestep.flatten()
|
||||
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
|
||||
# We can do this because num_frames == post_patch_num_frames, as p_t is 1
|
||||
temb, embedded_timestep = (
|
||||
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
|
||||
.expand(-1, -1, post_patch_height, post_patch_width, -1)
|
||||
.flatten(1, 3)
|
||||
for x in (temb, embedded_timestep)
|
||||
) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
|
||||
else:
|
||||
assert False
|
||||
|
||||
# 5. Transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
@@ -544,8 +568,8 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
|
||||
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
|
||||
# Please just kill me at this point. What even is this permutation order and why is it different from the patching order?
|
||||
# Another few hours of sanity lost to the void.
|
||||
# NOTE: The permutation order here is not the inverse operation of what happens when patching as usually expected.
|
||||
# It might be a source of confusion to the reader, but this is correct
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ class FluxTransformer2DModel(
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
|
||||
@@ -389,9 +389,7 @@ class MOEFeedForwardSwiGLU(nn.Module):
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
|
||||
tokens_per_expert = count_freq.cumsum(dim=0)
|
||||
|
||||
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||
token_idxs = idxs // self.num_activated_experts
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
|
||||
|
||||
@@ -0,0 +1,393 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanVACETransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
apply_input_projection: bool = False,
|
||||
apply_output_projection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Input projection
|
||||
self.proj_in = None
|
||||
if apply_input_projection:
|
||||
self.proj_in = nn.Linear(dim, dim)
|
||||
|
||||
# 2. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Cross-attention
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
added_proj_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
# 4. Feed-forward
|
||||
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
||||
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
# 5. Output projection
|
||||
self.proj_out = None
|
||||
if apply_output_projection:
|
||||
self.proj_out = nn.Linear(dim, dim)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
control_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if self.proj_in is not None:
|
||||
control_hidden_states = self.proj_in(control_hidden_states)
|
||||
control_hidden_states = control_hidden_states + hidden_states
|
||||
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
|
||||
control_hidden_states
|
||||
)
|
||||
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
|
||||
control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
|
||||
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
control_hidden_states = control_hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
|
||||
control_hidden_states
|
||||
)
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
|
||||
control_hidden_states
|
||||
)
|
||||
|
||||
conditioning_states = None
|
||||
if self.proj_out is not None:
|
||||
conditioning_states = self.proj_out(control_hidden_states)
|
||||
|
||||
return conditioning_states, control_hidden_states
|
||||
|
||||
|
||||
class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in the Wan model.
|
||||
|
||||
Args:
|
||||
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
|
||||
num_attention_heads (`int`, defaults to `40`):
|
||||
Fixed length for text embeddings.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_dim (`int`, defaults to `512`):
|
||||
Input dimension for text embeddings.
|
||||
freq_dim (`int`, defaults to `256`):
|
||||
Dimension for sinusoidal time embeddings.
|
||||
ffn_dim (`int`, defaults to `13824`):
|
||||
Intermediate dimension in feed-forward network.
|
||||
num_layers (`int`, defaults to `40`):
|
||||
The number of layers of transformer blocks to use.
|
||||
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
|
||||
Window size for local attention (-1 indicates global attention).
|
||||
cross_attn_norm (`bool`, defaults to `True`):
|
||||
Enable cross-attention normalization.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Enable query/key normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
add_img_emb (`bool`, defaults to `False`):
|
||||
Whether to use img_emb.
|
||||
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
|
||||
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
text_dim: int = 4096,
|
||||
freq_dim: int = 256,
|
||||
ffn_dim: int = 13824,
|
||||
num_layers: int = 40,
|
||||
cross_attn_norm: bool = True,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
eps: float = 1e-6,
|
||||
image_dim: Optional[int] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
rope_max_seq_len: int = 1024,
|
||||
pos_embed_seq_len: Optional[int] = None,
|
||||
vace_layers: List[int] = [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
vace_in_channels: int = 96,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
if max(vace_layers) >= num_layers:
|
||||
raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")
|
||||
if 0 not in vace_layers:
|
||||
raise ValueError("VACE layers must include layer 0.")
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Condition embeddings
|
||||
# image_embedding_dim=1280 for I2V model
|
||||
self.condition_embedder = WanTimeTextImageEmbedding(
|
||||
dim=inner_dim,
|
||||
time_freq_dim=freq_dim,
|
||||
time_proj_dim=inner_dim * 6,
|
||||
text_embed_dim=text_dim,
|
||||
image_embed_dim=image_dim,
|
||||
pos_embed_seq_len=pos_embed_seq_len,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
WanTransformerBlock(
|
||||
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.vace_blocks = nn.ModuleList(
|
||||
[
|
||||
WanVACETransformerBlock(
|
||||
inner_dim,
|
||||
ffn_dim,
|
||||
num_attention_heads,
|
||||
qk_norm,
|
||||
cross_attn_norm,
|
||||
eps,
|
||||
added_kv_proj_dim,
|
||||
apply_input_projection=i == 0, # Layer 0 always has input projection and is in vace_layers
|
||||
apply_output_projection=True,
|
||||
)
|
||||
for i in range(len(vace_layers))
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
control_hidden_states: torch.Tensor = None,
|
||||
control_hidden_states_scale: torch.Tensor = None,
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
if control_hidden_states_scale is None:
|
||||
control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
|
||||
control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
|
||||
if len(control_hidden_states_scale) != len(self.config.vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
|
||||
f"equal to {len(self.config.vace_layers)}."
|
||||
)
|
||||
|
||||
# 1. Rotary position embedding
|
||||
rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Patch embedding
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
|
||||
control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
|
||||
control_hidden_states_padding = control_hidden_states.new_zeros(
|
||||
batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
|
||||
)
|
||||
control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1)
|
||||
|
||||
# 3. Time embedding
|
||||
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
|
||||
timestep, encoder_hidden_states, encoder_hidden_states_image
|
||||
)
|
||||
timestep_proj = timestep_proj.unflatten(1, (6, -1))
|
||||
|
||||
# 4. Image embedding
|
||||
if encoder_hidden_states_image is not None:
|
||||
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
|
||||
|
||||
# 5. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# Prepare VACE hints
|
||||
control_hidden_states_list = []
|
||||
for i, block in enumerate(self.vace_blocks):
|
||||
conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
|
||||
control_hidden_states_list = control_hidden_states_list[::-1]
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
if i in self.config.vace_layers:
|
||||
control_hint, scale = control_hidden_states_list.pop()
|
||||
hidden_states = hidden_states + control_hint * scale
|
||||
else:
|
||||
# Prepare VACE hints
|
||||
control_hidden_states_list = []
|
||||
for i, block in enumerate(self.vace_blocks):
|
||||
conditioning_states, control_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
|
||||
control_hidden_states_list = control_hidden_states_list[::-1]
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
|
||||
if i in self.config.vace_layers:
|
||||
control_hint, scale = control_hidden_states_list.pop()
|
||||
hidden_states = hidden_states + control_hint * scale
|
||||
|
||||
# 6. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
# first device rather than the last device, which hidden_states ends up
|
||||
# on.
|
||||
shift = shift.to(hidden_states.device)
|
||||
scale = scale.to(hidden_states.device)
|
||||
|
||||
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -148,6 +148,7 @@ else:
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
]
|
||||
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
|
||||
_import_structure["chroma"] = ["ChromaPipeline"]
|
||||
_import_structure["cogvideo"] = [
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
@@ -157,7 +158,12 @@ else:
|
||||
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["cosmos"] = ["CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline"]
|
||||
_import_structure["cosmos"] = [
|
||||
"Cosmos2TextToImagePipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
"CosmosVideoToWorldPipeline",
|
||||
"Cosmos2VideoToWorldPipeline",
|
||||
]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
@@ -371,7 +377,7 @@ else:
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"]
|
||||
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -531,6 +537,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .chroma import ChromaPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
@@ -559,7 +566,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
)
|
||||
from .cosmos import CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
|
||||
from .cosmos import (
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
CosmosVideoToWorldPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
@@ -739,7 +751,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
|
||||
@@ -21,7 +21,7 @@ from ...image_processor import VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -47,7 +47,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class AmusedPipeline(DiffusionPipeline):
|
||||
class AmusedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.33.1"
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
@@ -21,7 +21,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -57,7 +57,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class AmusedImg2ImgPipeline(DiffusionPipeline):
|
||||
class AmusedImg2ImgPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.33.1"
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -65,7 +65,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class AmusedInpaintPipeline(DiffusionPipeline):
|
||||
class AmusedInpaintPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.33.1"
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -57,7 +57,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class AudioLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
r"""
|
||||
Pipeline for text-to-audio generation using AudioLDM.
|
||||
|
||||
@@ -81,6 +81,7 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
Vocoder of class `SpeechT5HifiGan`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -21,6 +21,7 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..models.controlnets import ControlNetUnionModel
|
||||
from ..utils import is_sentencepiece_available
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .chroma import ChromaPipeline
|
||||
from .cogview3 import CogView3PlusPipeline
|
||||
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
|
||||
from .controlnet import (
|
||||
@@ -143,6 +144,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("flux-controlnet", FluxControlNetPipeline),
|
||||
("lumina", LuminaPipeline),
|
||||
("lumina2", Lumina2Pipeline),
|
||||
("chroma", ChromaPipeline),
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
("cogview4", CogView4Pipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import (
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .blip_image_processing import BlipImageProcessor
|
||||
from .modeling_blip2 import Blip2QFormerModel
|
||||
from .modeling_ctx_clip import ContextCLIPTextModel
|
||||
@@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class BlipDiffusionPipeline(DiffusionPipeline):
|
||||
class BlipDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
|
||||
|
||||
@@ -107,6 +107,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
|
||||
Position of the context token in the text encoder.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["ChromaPipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_chroma import ChromaPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,863 @@
|
||||
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ChromaTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import ChromaPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaPipeline
|
||||
|
||||
>>> pipe = ChromaPipeline.from_single_file(
|
||||
... "chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
|
||||
>>> image.save("chroma.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ChromaPipeline(
|
||||
DiffusionPipeline,
|
||||
FluxLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
FluxIPAdapterMixin,
|
||||
):
|
||||
r"""
|
||||
The Chroma pipeline for text-to-image generation.
|
||||
|
||||
Reference: https://huggingface.co/lodestones/Chroma/
|
||||
|
||||
Args:
|
||||
transformer ([`ChromaTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
|
||||
_optional_components = ["image_encoder", "feature_extractor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: ChromaTransformer2DModel,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.default_sample_size = 128
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask.clone()
|
||||
|
||||
# Chroma requires the attention mask to include one padding token
|
||||
seq_lengths = attention_mask.sum(dim=1)
|
||||
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
|
||||
)[0]
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
max_sequence_length: int = 512,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
negative_text_ids = None
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = (
|
||||
batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
)
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
||||
):
|
||||
image_embeds = []
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_ip_adapter_image in ip_adapter_image:
|
||||
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
||||
image_embeds.append(single_image_embeds[None, :])
|
||||
else:
|
||||
if not isinstance(ip_adapter_image_embeds, list):
|
||||
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
||||
|
||||
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
ip_adapter_image_embeds = []
|
||||
for single_image_embeds in image_embeds:
|
||||
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_image_embeds = single_image_embeds.to(device=device)
|
||||
ip_adapter_image_embeds.append(single_image_embeds)
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 3.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
not greater than `1`).
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
negative_ip_adapter_image:
|
||||
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
text_ids,
|
||||
negative_prompt_embeds,
|
||||
negative_text_ids,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
||||
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
||||
):
|
||||
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
|
||||
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
||||
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
||||
):
|
||||
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
|
||||
if self.joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
|
||||
image_embeds = None
|
||||
negative_image_embeds = None
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
||||
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
negative_ip_adapter_image,
|
||||
negative_ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
if image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
txt_ids=negative_text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ChromaPipelineOutput(images=image)
|
||||
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Stable Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -37,7 +37,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -98,6 +98,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
|
||||
class StableDiffusionControlNetXSPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
@@ -138,6 +139,7 @@ class StableDiffusionControlNetXSPipeline(
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
@@ -46,7 +46,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
|
||||
|
||||
@@ -114,6 +114,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
@@ -158,6 +159,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
watermarker is used.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = [
|
||||
"tokenizer",
|
||||
|
||||
@@ -22,6 +22,8 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
|
||||
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
|
||||
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
|
||||
_import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
|
||||
|
||||
@@ -33,6 +35,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
|
||||
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
|
||||
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
|
||||
from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
|
||||
|
||||
|
||||
@@ -0,0 +1,673 @@
|
||||
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosImagePipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import Cosmos2TextToImagePipeline
|
||||
|
||||
>>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Text2Image, nvidia/Cosmos-Predict2-14B-Text2Image
|
||||
>>> model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
|
||||
>>> pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
|
||||
>>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
|
||||
... ).images[0]
|
||||
>>> output.save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class Cosmos2TextToImagePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2).
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Cosmos uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CosmosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
self.sigma_max = 80.0
|
||||
self.sigma_min = 0.002
|
||||
self.sigma_data = 1.0
|
||||
self.final_sigmas_type = "sigma_min"
|
||||
if self.scheduler is not None:
|
||||
self.scheduler.register_to_config(
|
||||
sigma_max=self.sigma_max,
|
||||
sigma_min=self.sigma_min,
|
||||
sigma_data=self.sigma_data,
|
||||
final_sigmas_type=self.final_sigmas_type,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=True,
|
||||
return_offsets_mapping=False,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=prompt_attention_mask
|
||||
).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
lengths = prompt_attention_mask.sum(dim=1).cpu()
|
||||
for i, length in enumerate(lengths):
|
||||
prompt_embeds[i, length:] = 0
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt with num_videos_per_prompt->num_images_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: 16,
|
||||
height: int = 768,
|
||||
width: int = 1360,
|
||||
num_frames: int = 1,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents * self.scheduler.config.sigma_max
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 768,
|
||||
width: int = 1360,
|
||||
num_inference_steps: int = 35,
|
||||
guidance_scale: float = 7.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `768`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1360`):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, defaults to `35`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `7.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosImagePipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosImagePipelineOutput`] is returned, otherwise a `tuple` is returned
|
||||
where the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if self.safety_checker is None:
|
||||
raise ValueError(
|
||||
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||
f"Please ensure that you are compliant with the license agreement."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
num_frames = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
if prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
for p in prompt_list:
|
||||
if not self.safety_checker.check_text_safety(p):
|
||||
raise ValueError(
|
||||
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||
)
|
||||
self.safety_checker.to("cpu")
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
|
||||
if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
|
||||
# Replace the last sigma (which is zero) with the minimum sigma value
|
||||
self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
|
||||
|
||||
# 5. Prepare latent variables
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
current_sigma = self.scheduler.sigmas[i]
|
||||
|
||||
current_t = current_sigma / (current_sigma + 1)
|
||||
c_in = 1 - current_t
|
||||
c_skip = 1 - current_t
|
||||
c_out = -current_t
|
||||
timestep = current_t.expand(latents.shape[0]).to(transformer_dtype) # [B, 1, T, 1, 1]
|
||||
|
||||
latent_model_input = latents * c_in
|
||||
latent_model_input = latent_model_input.to(transformer_dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
|
||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)
|
||||
|
||||
noise_pred = (latents - noise_pred) / current_sigma
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std / self.scheduler.config.sigma_data + latents_mean
|
||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
video = (video * 255).astype(np.uint8)
|
||||
video_batch = []
|
||||
for vid in video:
|
||||
vid = self.safety_checker.check_video_safety(vid)
|
||||
video_batch.append(vid)
|
||||
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
self.safety_checker.to("cpu")
|
||||
else:
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
image = [batch[0] for batch in video]
|
||||
if isinstance(video, torch.Tensor):
|
||||
image = torch.stack(image)
|
||||
elif isinstance(video, np.ndarray):
|
||||
image = np.stack(image)
|
||||
else:
|
||||
image = latents[:, :, 0]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return CosmosImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,792 @@
|
||||
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import Cosmos2VideoToWorldPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
|
||||
>>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Video2World, nvidia/Cosmos-Predict2-14B-Video2World
|
||||
>>> model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
|
||||
>>> pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
|
||||
>>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yellow-scrubber.png"
|
||||
... )
|
||||
|
||||
>>> video = pipe(
|
||||
... image=image, prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=16)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for video-to-world generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2).
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Cosmos uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CosmosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
self.sigma_max = 80.0
|
||||
self.sigma_min = 0.002
|
||||
self.sigma_data = 1.0
|
||||
self.final_sigmas_type = "sigma_min"
|
||||
if self.scheduler is not None:
|
||||
self.scheduler.register_to_config(
|
||||
sigma_max=self.sigma_max,
|
||||
sigma_min=self.sigma_min,
|
||||
sigma_data=self.sigma_data,
|
||||
final_sigmas_type=self.final_sigmas_type,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=True,
|
||||
return_offsets_mapping=False,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=prompt_attention_mask
|
||||
).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
lengths = prompt_attention_mask.sum(dim=1).cpu()
|
||||
for i, length in enumerate(lengths):
|
||||
prompt_embeds[i, length:] = 0
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_channels_latents: 16,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 93,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
num_cond_frames = video.size(2)
|
||||
if num_cond_frames >= num_frames:
|
||||
# Take the last `num_frames` frames for conditioning
|
||||
num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
video = video[:, :, -num_frames:]
|
||||
else:
|
||||
num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
num_padding_frames = num_frames - num_cond_frames
|
||||
last_frame = video[:, :, -1:]
|
||||
padding = last_frame.repeat(1, 1, num_padding_frames, 1, 1)
|
||||
video = torch.cat([video, padding], dim=2)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
|
||||
)
|
||||
init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.config.sigma_data
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latents = latents * self.scheduler.config.sigma_max
|
||||
|
||||
padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
|
||||
ones_padding = latents.new_ones(padding_shape)
|
||||
zeros_padding = latents.new_zeros(padding_shape)
|
||||
|
||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
cond_indicator[:, :, :num_cond_latent_frames] = 1.0
|
||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||
|
||||
uncond_indicator = uncond_mask = None
|
||||
if do_classifier_free_guidance:
|
||||
uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
|
||||
uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
|
||||
|
||||
return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput = None,
|
||||
video: List[PipelineImageInput] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 93,
|
||||
num_inference_steps: int = 35,
|
||||
guidance_scale: float = 7.0,
|
||||
fps: int = 16,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
sigma_conditioning: float = 0.0001,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
The image to be used as a conditioning input for the video generation.
|
||||
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
The video to be used as a conditioning input for the video generation.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `704`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `93`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `35`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `7.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`.
|
||||
fps (`int`, defaults to `16`):
|
||||
The frames per second of the generated video.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||
the prompt is shorter than this length, it will be padded.
|
||||
sigma_conditioning (`float`, defaults to `0.0001`):
|
||||
The sigma value used for scaling conditioning latents. Ideally, it should not be changed or should be
|
||||
set to a small value close to zero.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if self.safety_checker is None:
|
||||
raise ValueError(
|
||||
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||
f"Please ensure that you are compliant with the license agreement."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
if prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
for p in prompt_list:
|
||||
if not self.safety_checker.check_text_safety(p):
|
||||
raise ValueError(
|
||||
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||
)
|
||||
self.safety_checker.to("cpu")
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
|
||||
if self.scheduler.config.final_sigmas_type == "sigma_min":
|
||||
# Replace the last sigma (which is zero) with the minimum sigma value
|
||||
self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
|
||||
|
||||
# 5. Prepare latent variables
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
if image is not None:
|
||||
video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
|
||||
else:
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
video = video.to(device=device, dtype=vae_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels - 1
|
||||
latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
|
||||
video,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
self.do_classifier_free_guidance,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
unconditioning_latents = None
|
||||
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
if self.do_classifier_free_guidance:
|
||||
uncond_mask = uncond_mask.to(transformer_dtype)
|
||||
unconditioning_latents = conditioning_latents
|
||||
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
sigma_conditioning = torch.tensor(sigma_conditioning, dtype=torch.float32, device=device)
|
||||
t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
current_sigma = self.scheduler.sigmas[i]
|
||||
|
||||
current_t = current_sigma / (current_sigma + 1)
|
||||
c_in = 1 - current_t
|
||||
c_skip = 1 - current_t
|
||||
c_out = -current_t
|
||||
timestep = current_t.view(1, 1, 1, 1, 1).expand(
|
||||
latents.size(0), -1, latents.size(2), -1, -1
|
||||
) # [B, 1, T, 1, 1]
|
||||
|
||||
cond_latent = latents * c_in
|
||||
cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
|
||||
cond_latent = cond_latent.to(transformer_dtype)
|
||||
cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
|
||||
cond_timestep = cond_timestep.to(transformer_dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=cond_latent,
|
||||
timestep=cond_timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
fps=fps,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
|
||||
noise_pred = cond_indicator * conditioning_latents + (1 - cond_indicator) * noise_pred
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
uncond_latent = latents * c_in
|
||||
uncond_latent = uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * uncond_latent
|
||||
uncond_latent = uncond_latent.to(transformer_dtype)
|
||||
uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep
|
||||
uncond_timestep = uncond_timestep.to(transformer_dtype)
|
||||
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=uncond_latent,
|
||||
timestep=uncond_timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
fps=fps,
|
||||
condition_mask=uncond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
|
||||
noise_pred_uncond = (
|
||||
uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * noise_pred_uncond
|
||||
)
|
||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)
|
||||
|
||||
noise_pred = (latents - noise_pred) / current_sigma
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
|
||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
video = (video * 255).astype(np.uint8)
|
||||
video_batch = []
|
||||
for vid in video:
|
||||
vid = self.safety_checker.check_video_safety(vid)
|
||||
video_batch.append(vid)
|
||||
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
self.safety_checker.to("cpu")
|
||||
else:
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
@@ -131,7 +131,7 @@ def retrieve_timesteps(
|
||||
|
||||
class CosmosTextToWorldPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
Pipeline for text-to-world generation using [Cosmos Predict1](https://github.com/nvidia-cosmos/cosmos-predict1).
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
@@ -426,12 +426,12 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
num_frames (`int`, defaults to `121`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
num_inference_steps (`int`, defaults to `36`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
guidance_scale (`float`, defaults to `7.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
@@ -457,9 +457,6 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
|
||||
@@ -174,7 +174,8 @@ def retrieve_latents(
|
||||
|
||||
class CosmosVideoToWorldPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for image-to-video and video-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
Pipeline for image-to-world and video-to-world generation using [Cosmos
|
||||
Predict-1](https://github.com/nvidia-cosmos/cosmos-predict1).
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
@@ -541,12 +542,12 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
num_frames (`int`, defaults to `121`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
num_inference_steps (`int`, defaults to `36`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
guidance_scale (`float`, defaults to `7.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
@@ -572,9 +573,6 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
from diffusers.utils import BaseOutput, get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CosmosPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Cosmos pipelines.
|
||||
Output class for Cosmos any-to-world/video pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
@@ -18,3 +24,17 @@ class CosmosPipelineOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class CosmosImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Cosmos any-to-image pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
@@ -21,7 +21,7 @@ from ...models import UNet1DModel
|
||||
from ...schedulers import SchedulerMixin
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -34,7 +34,7 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
class DanceDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for audio generation.
|
||||
|
||||
@@ -49,6 +49,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
[`IPNDMScheduler`].
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
|
||||
|
||||
@@ -898,6 +898,7 @@ class FluxPipeline(
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
|
||||
@@ -1193,6 +1193,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
image = [
|
||||
self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
|
||||
]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -97,9 +97,11 @@ class I2VGenXLPipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
class I2VGenXLPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/).
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from ...utils import (
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
@@ -76,7 +76,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class MusicLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class MusicLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-to-audio generation using MusicLDM.
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import deprecate, is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from .image_encoder import PaintByExampleImageEncoder
|
||||
@@ -155,7 +155,8 @@ def prepare_mask_and_masked_image(image, mask):
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from ...utils import (
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -132,6 +132,7 @@ class PIAPipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
class PIAPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
@@ -140,6 +141,7 @@ class PIAPipeline(
|
||||
FromSingleFileMixin,
|
||||
FreeInitMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
|
||||
@@ -139,6 +139,43 @@ class AudioPipelineOutput(BaseOutput):
|
||||
audios: np.ndarray
|
||||
|
||||
|
||||
class DeprecatedPipelineMixin:
|
||||
"""
|
||||
A mixin that can be used to mark a pipeline as deprecated.
|
||||
|
||||
Pipelines inheriting from this mixin will raise a warning when instantiated, indicating that they are deprecated
|
||||
and won't receive updates past the specified version. Tests will be skipped for pipelines that inherit from this
|
||||
mixin.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
class MyDeprecatedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.20.0"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
```
|
||||
"""
|
||||
|
||||
# Override this in the inheriting class to specify the last version that will support this pipeline
|
||||
_last_supported_version = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Get the class name for the warning message
|
||||
class_name = self.__class__.__name__
|
||||
|
||||
# Get the last supported version or use the current version if not specified
|
||||
version_info = getattr(self.__class__, "_last_supported_version", __version__)
|
||||
|
||||
# Raise a warning that this pipeline is deprecated
|
||||
logger.warning(
|
||||
f"The {class_name} has been deprecated and will not receive bug fixes or feature updates after Diffusers version {version_info}. "
|
||||
)
|
||||
|
||||
# Call the parent class's __init__ method
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
r"""
|
||||
Base class for all pipelines.
|
||||
@@ -632,14 +669,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn’t need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
device_map (`str`, *optional*):
|
||||
Strategy that dictates how the different components of a pipeline should be placed on available
|
||||
devices. Currently, only "balanced" `device_map` is supported. Check out
|
||||
[this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement)
|
||||
to know more.
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
||||
each GPU and the available CPU RAM if unset.
|
||||
|
||||
+3
-2
@@ -11,7 +11,7 @@ from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyCh
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import SemanticStableDiffusionPipelineOutput
|
||||
|
||||
|
||||
@@ -25,7 +25,8 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class SemanticStableDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with latent editing.
|
||||
|
||||
|
||||
+6
-2
@@ -37,7 +37,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -179,7 +179,9 @@ class AttendExciteAttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin):
|
||||
class StableDiffusionAttendAndExcitePipeline(
|
||||
DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion and Attend-and-Excite.
|
||||
|
||||
@@ -209,6 +211,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
+8
-2
@@ -40,7 +40,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -242,7 +242,11 @@ def preprocess_mask(mask, batch_size: int = 1):
|
||||
|
||||
|
||||
class StableDiffusionDiffEditPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
<Tip warning={true}>
|
||||
@@ -282,6 +286,8 @@ class StableDiffusionDiffEditPipeline(
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
@@ -36,7 +36,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -108,7 +108,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class StableDiffusionGLIGENPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN).
|
||||
|
||||
@@ -135,6 +135,8 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
+4
-2
@@ -41,7 +41,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -160,7 +160,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class StableDiffusionGLIGENTextImagePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN).
|
||||
|
||||
@@ -193,6 +193,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
+8
-2
@@ -42,7 +42,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
@@ -64,7 +64,11 @@ class ModelWrapper:
|
||||
|
||||
|
||||
class StableDiffusionKDiffusionPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
@@ -105,6 +109,8 @@ class StableDiffusionKDiffusionPipeline(
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
+4
-1
@@ -48,7 +48,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
|
||||
|
||||
@@ -88,6 +88,7 @@ class ModelWrapper:
|
||||
|
||||
|
||||
class StableDiffusionXLKDiffusionPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
FromSingleFileMixin,
|
||||
@@ -95,6 +96,8 @@ class StableDiffusionXLKDiffusionPipeline(
|
||||
TextualInversionLoaderMixin,
|
||||
IPAdapterMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL and k-diffusion.
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ class LDM3DPipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
class StableDiffusionLDM3DPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
@@ -185,6 +186,8 @@ class StableDiffusionLDM3DPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline for text-to-image and 3D generation using LDM3D.
|
||||
|
||||
|
||||
+4
-1
@@ -33,7 +33,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -156,12 +156,15 @@ def retrieve_timesteps(
|
||||
|
||||
|
||||
class StableDiffusionPanoramaPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
IPAdapterMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline for text-to-image generation using MultiDiffusion.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from . import StableDiffusionSafePipelineOutput
|
||||
from .safety_checker import SafeStableDiffusionSafetyChecker
|
||||
|
||||
@@ -29,7 +29,9 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin):
|
||||
class StableDiffusionPipelineSafe(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion.
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -107,7 +107,11 @@ class CrossAttnStoreProcessor:
|
||||
|
||||
|
||||
# Modified to get self-attention guidance scale in this paper (https://huggingface.co/papers/2210.00939) as an input
|
||||
class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin):
|
||||
class StableDiffusionSAGPipeline(
|
||||
DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from . import TextToVideoSDPipelineOutput
|
||||
|
||||
|
||||
@@ -68,8 +68,13 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
|
||||
class TextToVideoSDPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
|
||||
+7
-2
@@ -34,7 +34,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from . import TextToVideoSDPipelineOutput
|
||||
|
||||
|
||||
@@ -103,8 +103,13 @@ def retrieve_latents(
|
||||
|
||||
|
||||
class VideoToVideoSDPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-guided video-to-video generation.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
@@ -296,12 +296,14 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
|
||||
|
||||
|
||||
class TextToVideoZeroPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
@@ -346,11 +346,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
|
||||
|
||||
class TextToVideoZeroSDXLPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for zero-shot text-to-video generation using Stable Diffusion XL.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class UnCLIPPipeline(DiffusionPipeline):
|
||||
class UnCLIPPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using unCLIP.
|
||||
|
||||
@@ -69,6 +69,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
_exclude_from_cpu_offload = ["prior"]
|
||||
|
||||
prior: PriorTransformer
|
||||
|
||||
@@ -29,7 +29,7 @@ from ...models import UNet2DConditionModel, UNet2DModel
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
class UnCLIPImageVariationPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline to generate image variations from an input image using UnCLIP.
|
||||
|
||||
@@ -73,6 +73,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]).
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
decoder: UNet2DConditionModel
|
||||
text_proj: UnCLIPTextProjModel
|
||||
text_encoder: CLIPTextModelWithProjection
|
||||
|
||||
@@ -28,7 +28,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.outputs import BaseOutput
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
|
||||
from .modeling_text_decoder import UniDiffuserTextDecoder
|
||||
from .modeling_uvit import UniDiffuserModel
|
||||
|
||||
@@ -62,7 +62,7 @@ class ImageTextPipelineOutput(BaseOutput):
|
||||
text: Optional[Union[List[str], List[List[str]]]]
|
||||
|
||||
|
||||
class UniDiffuserPipeline(DiffusionPipeline):
|
||||
class UniDiffuserPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for a bimodal image-text model which supports unconditional text and image generation, text-conditioned
|
||||
image generation, image-conditioned text generation, and joint image-text generation.
|
||||
@@ -96,6 +96,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
# TODO: support for moving submodules for components with enable_model_cpu_offload
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae->text_decoder"
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_wan"] = ["WanPipeline"]
|
||||
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
|
||||
_import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"]
|
||||
_import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_wan import WanPipeline
|
||||
from .pipeline_wan_i2v import WanImageToVideoPipeline
|
||||
from .pipeline_wan_vace import WanVACEPipeline
|
||||
from .pipeline_wan_video2video import WanVideoToVideoPipeline
|
||||
|
||||
else:
|
||||
|
||||
@@ -388,8 +388,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (`guidance_scale` < `1`).
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
@@ -434,8 +436,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The dtype to use for the torch.amp.autocast.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@@ -562,12 +562,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to `512`):
|
||||
The maximum sequence length of the prompt.
|
||||
shift (`float`, *optional*, defaults to `5.0`):
|
||||
The shift of the flow.
|
||||
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The dtype to use for the torch.amp.autocast.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -0,0 +1,976 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, WanVACETransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import WanPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import PIL.Image
|
||||
>>> from diffusers import AutoencoderKLWan, WanVACEPipeline
|
||||
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
def prepare_video_and_mask(first_img: PIL.Image.Image, last_img: PIL.Image.Image, height: int, width: int, num_frames: int):
|
||||
first_img = first_img.resize((width, height))
|
||||
last_img = last_img.resize((width, height))
|
||||
frames = []
|
||||
frames.append(first_img)
|
||||
# Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays
|
||||
# whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to
|
||||
# match the original code.
|
||||
frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
|
||||
frames.append(last_img)
|
||||
mask_black = PIL.Image.new("L", (width, height), 0)
|
||||
mask_white = PIL.Image.new("L", (width, height), 255)
|
||||
mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
|
||||
return frames, mask
|
||||
|
||||
>>> # Available checkpoints: Wan-AI/Wan2.1-VACE-1.3B-diffusers, Wan-AI/Wan2.1-VACE-14B-diffusers
|
||||
>>> model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
|
||||
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
>>> first_frame = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
|
||||
... )
|
||||
>>> last_frame = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png>>> "
|
||||
... )
|
||||
|
||||
>>> height = 512
|
||||
>>> width = 512
|
||||
>>> num_frames = 81
|
||||
>>> video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)
|
||||
|
||||
>>> output = pipe(
|
||||
... video=video,
|
||||
... mask=mask,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=height,
|
||||
... width=width,
|
||||
... num_frames=num_frames,
|
||||
... num_inference_steps=30,
|
||||
... guidance_scale=5.0,
|
||||
... generator=torch.Generator().manual_seed(42),
|
||||
... ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for controllable generation using Wan.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanVACETransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
video=None,
|
||||
mask=None,
|
||||
reference_images=None,
|
||||
):
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
if height % base != 0 or width % base != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if video is not None:
|
||||
if mask is not None:
|
||||
if len(video) != len(mask):
|
||||
raise ValueError(
|
||||
f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that"
|
||||
" they have the same length."
|
||||
)
|
||||
if reference_images is not None:
|
||||
is_pil_image = isinstance(reference_images, PIL.Image.Image)
|
||||
is_list_of_pil_images = isinstance(reference_images, list) and all(
|
||||
isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images
|
||||
)
|
||||
is_list_of_list_of_pil_images = isinstance(reference_images, list) and all(
|
||||
isinstance(ref_img, list) and all(isinstance(ref_img_, PIL.Image.Image) for ref_img_ in ref_img)
|
||||
for ref_img in reference_images
|
||||
)
|
||||
if not (is_pil_image or is_list_of_pil_images or is_list_of_list_of_pil_images):
|
||||
raise ValueError(
|
||||
"`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
|
||||
"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
|
||||
)
|
||||
if is_list_of_list_of_pil_images and len(reference_images) != 1:
|
||||
raise ValueError(
|
||||
"The pipeline only supports generating one video at a time at the moment. When passing a list "
|
||||
"of list of reference images, where the outer list corresponds to the batch size and the inner "
|
||||
"list corresponds to list of conditioning images per video, please make sure to only pass "
|
||||
"one inner list of reference images (i.e., `[[<image1>, <image2>, ...]]`"
|
||||
)
|
||||
elif mask is not None:
|
||||
raise ValueError("`mask` can only be passed if `video` is passed as well.")
|
||||
|
||||
def preprocess_conditions(
|
||||
self,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
mask: Optional[List[PipelineImageInput]] = None,
|
||||
reference_images: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
|
||||
batch_size: int = 1,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
if video is not None:
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
video_height, video_width = self.video_processor.get_default_height_width(video[0])
|
||||
|
||||
if video_height * video_width > height * width:
|
||||
scale = min(width / video_width, height / video_height)
|
||||
video_height, video_width = int(video_height * scale), int(video_width * scale)
|
||||
|
||||
if video_height % base != 0 or video_width % base != 0:
|
||||
logger.warning(
|
||||
f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
|
||||
)
|
||||
video_height = (video_height // base) * base
|
||||
video_width = (video_width // base) * base
|
||||
|
||||
assert video_height * video_width <= height * width
|
||||
|
||||
video = self.video_processor.preprocess_video(video, video_height, video_width)
|
||||
image_size = (video_height, video_width) # Use the height/width of video (with possible rescaling)
|
||||
else:
|
||||
video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device)
|
||||
image_size = (height, width) # Use the height/width provider by user
|
||||
|
||||
if mask is not None:
|
||||
mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
|
||||
mask = torch.clamp((mask + 1) / 2, min=0, max=1)
|
||||
else:
|
||||
mask = torch.ones_like(video)
|
||||
|
||||
video = video.to(dtype=dtype, device=device)
|
||||
mask = mask.to(dtype=dtype, device=device)
|
||||
|
||||
# Make a list of list of images where the outer list corresponds to video batch size and the inner list
|
||||
# corresponds to list of conditioning images per video
|
||||
if reference_images is None or isinstance(reference_images, PIL.Image.Image):
|
||||
reference_images = [[reference_images] for _ in range(video.shape[0])]
|
||||
elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image):
|
||||
reference_images = [reference_images]
|
||||
elif (
|
||||
isinstance(reference_images, (list, tuple))
|
||||
and isinstance(next(iter(reference_images)), list)
|
||||
and isinstance(next(iter(reference_images[0])), PIL.Image.Image)
|
||||
):
|
||||
reference_images = reference_images
|
||||
else:
|
||||
raise ValueError(
|
||||
"`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
|
||||
"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
|
||||
)
|
||||
|
||||
if video.shape[0] != len(reference_images):
|
||||
raise ValueError(
|
||||
f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
|
||||
)
|
||||
|
||||
ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images]
|
||||
if any(l != ref_images_lengths[0] for l in ref_images_lengths):
|
||||
raise ValueError(
|
||||
f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this "
|
||||
"may be added in the future."
|
||||
)
|
||||
|
||||
reference_images_preprocessed = []
|
||||
for i, reference_images_batch in enumerate(reference_images):
|
||||
preprocessed_images = []
|
||||
for j, image in enumerate(reference_images_batch):
|
||||
if image is None:
|
||||
continue
|
||||
image = self.video_processor.preprocess(image, None, None)
|
||||
img_height, img_width = image.shape[-2:]
|
||||
scale = min(image_size[0] / img_height, image_size[1] / img_width)
|
||||
new_height, new_width = int(img_height * scale), int(img_width * scale)
|
||||
resized_image = torch.nn.functional.interpolate(
|
||||
image, size=(new_height, new_width), mode="bilinear", align_corners=False
|
||||
).squeeze(0) # [C, H, W]
|
||||
top = (image_size[0] - new_height) // 2
|
||||
left = (image_size[1] - new_width) // 2
|
||||
canvas = torch.ones(3, *image_size, device=device, dtype=dtype)
|
||||
canvas[:, top : top + new_height, left : left + new_width] = resized_image
|
||||
preprocessed_images.append(canvas)
|
||||
reference_images_preprocessed.append(preprocessed_images)
|
||||
|
||||
return video, mask, reference_images_preprocessed
|
||||
|
||||
def prepare_video_latents(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
reference_images: Optional[List[List[torch.Tensor]]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
device = device or self._execution_device
|
||||
|
||||
if isinstance(generator, list):
|
||||
# TODO: support this
|
||||
raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
|
||||
|
||||
if reference_images is None:
|
||||
# For each batch of video, we set no re
|
||||
# ference image (as one or more can be passed by user)
|
||||
reference_images = [[None] for _ in range(video.shape[0])]
|
||||
else:
|
||||
if video.shape[0] != len(reference_images):
|
||||
raise ValueError(
|
||||
f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
|
||||
)
|
||||
|
||||
if video.shape[0] != 1:
|
||||
# TODO: support this
|
||||
raise ValueError(
|
||||
"Generating with more than one video is not yet supported. This may be supported in the future."
|
||||
)
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
video = video.to(dtype=vae_dtype)
|
||||
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(
|
||||
1, self.vae.config.z_dim, 1, 1, 1
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(
|
||||
1, self.vae.config.z_dim, 1, 1, 1
|
||||
)
|
||||
|
||||
if mask is None:
|
||||
latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
|
||||
latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
|
||||
else:
|
||||
mask = mask.to(dtype=vae_dtype)
|
||||
mask = torch.where(mask > 0.5, 1.0, 0.0)
|
||||
inactive = video * (1 - mask)
|
||||
reactive = video * mask
|
||||
inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
|
||||
reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
|
||||
inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
|
||||
reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
|
||||
latents = torch.cat([inactive, reactive], dim=1)
|
||||
|
||||
latent_list = []
|
||||
for latent, reference_images_batch in zip(latents, reference_images):
|
||||
for reference_image in reference_images_batch:
|
||||
assert reference_image.ndim == 3
|
||||
reference_image = reference_image.to(dtype=vae_dtype)
|
||||
reference_image = reference_image[None, :, None, :, :] # [1, C, 1, H, W]
|
||||
reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
|
||||
reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
|
||||
reference_latent = reference_latent.squeeze(0) # [C, 1, H, W]
|
||||
reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=0)
|
||||
latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)
|
||||
latent_list.append(latent)
|
||||
return torch.stack(latent_list)
|
||||
|
||||
def prepare_masks(
|
||||
self,
|
||||
mask: torch.Tensor,
|
||||
reference_images: Optional[List[torch.Tensor]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list):
|
||||
# TODO: support this
|
||||
raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
|
||||
|
||||
if reference_images is None:
|
||||
# For each batch of video, we set no reference image (as one or more can be passed by user)
|
||||
reference_images = [[None] for _ in range(mask.shape[0])]
|
||||
else:
|
||||
if mask.shape[0] != len(reference_images):
|
||||
raise ValueError(
|
||||
f"Batch size of `mask` {mask.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
|
||||
)
|
||||
|
||||
if mask.shape[0] != 1:
|
||||
# TODO: support this
|
||||
raise ValueError(
|
||||
"Generating with more than one video is not yet supported. This may be supported in the future."
|
||||
)
|
||||
|
||||
transformer_patch_size = self.transformer.config.patch_size[1]
|
||||
|
||||
mask_list = []
|
||||
for mask_, reference_images_batch in zip(mask, reference_images):
|
||||
num_channels, num_frames, height, width = mask_.shape
|
||||
new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal
|
||||
new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
|
||||
new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
|
||||
mask_ = mask_[0, :, :, :]
|
||||
mask_ = mask_.view(
|
||||
num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial
|
||||
)
|
||||
mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1) # [8x8, num_frames, new_height, new_width]
|
||||
mask_ = torch.nn.functional.interpolate(
|
||||
mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
|
||||
).squeeze(0)
|
||||
num_ref_images = len(reference_images_batch)
|
||||
if num_ref_images > 0:
|
||||
mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :])
|
||||
mask_ = torch.cat([mask_, mask_padding], dim=1)
|
||||
mask_list.append(mask_)
|
||||
return torch.stack(mask_list)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
mask: Optional[List[PipelineImageInput]] = None,
|
||||
reference_images: Optional[List[PipelineImageInput]] = None,
|
||||
conditioning_scale: Union[float, List[float], torch.Tensor] = 1.0,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
video (`List[PIL.Image.Image]`, *optional*):
|
||||
The input video or videos to be used as a starting point for the generation. The video should be a list
|
||||
of PIL images, a numpy array, or a torch tensor. Currently, the pipeline only supports generating one
|
||||
video at a time.
|
||||
mask (`List[PIL.Image.Image]`, *optional*):
|
||||
The input mask defines which video regions to condition on and which to generate. Black areas in the
|
||||
mask indicate conditioning regions, while white areas indicate regions for generation. The mask should
|
||||
be a list of PIL images, a numpy array, or a torch tensor. Currently supports generating a single video
|
||||
at a time.
|
||||
reference_images (`List[PIL.Image.Image]`, *optional*):
|
||||
A list of one or more reference images as extra conditioning for the generation. For example, if you
|
||||
are trying to inpaint a video to change the character, you can pass reference images of the new
|
||||
character here. Refer to the Diffusers [examples](https://github.com/huggingface/diffusers/pull/11582)
|
||||
and original [user
|
||||
guide](https://github.com/ali-vilab/VACE/blob/0897c6d055d7d9ea9e191dce763006664d9780f8/UserGuide.md)
|
||||
for a full list of supported tasks and use cases.
|
||||
conditioning_scale (`float`, `List[float]`, `torch.Tensor`, defaults to `1.0`):
|
||||
The conditioning scale to be applied when adding the control conditioning latent stream to the
|
||||
denoising latent stream in each control layer of the model. If a float is provided, it will be applied
|
||||
uniformly to all layers. If a list or tensor is provided, it should have the same length as the number
|
||||
of control layers in the model (`len(transformer.config.vace_layers)`).
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~WanPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# Simplification of implementation for now
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError("Passing a list of prompts is not yet supported. This may be supported in the future.")
|
||||
if num_videos_per_prompt != 1:
|
||||
raise ValueError(
|
||||
"Generating multiple videos per prompt is not yet supported. This may be supported in the future."
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
video,
|
||||
mask,
|
||||
reference_images,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
logger.warning(
|
||||
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
||||
)
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
if isinstance(conditioning_scale, (int, float)):
|
||||
conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
|
||||
if isinstance(conditioning_scale, list):
|
||||
if len(conditioning_scale) != len(self.transformer.config.vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
|
||||
)
|
||||
conditioning_scale = torch.tensor(conditioning_scale)
|
||||
if isinstance(conditioning_scale, torch.Tensor):
|
||||
if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
|
||||
)
|
||||
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
video, mask, reference_images = self.preprocess_conditions(
|
||||
video,
|
||||
mask,
|
||||
reference_images,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
)
|
||||
num_reference_images = len(reference_images[0])
|
||||
|
||||
conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
|
||||
mask = self.prepare_masks(mask, reference_images, generator)
|
||||
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
|
||||
conditioning_latents = conditioning_latents.to(transformer_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames + num_reference_images * self.vae_scale_factor_temporal,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if conditioning_latents.shape[2] != latents.shape[2]:
|
||||
logger.warning(
|
||||
"The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected."
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents[:, :, num_reference_images:]
|
||||
latents = latents.to(vae_dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return WanPipelineOutput(frames=video)
|
||||
@@ -419,12 +419,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
|
||||
@@ -441,7 +436,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
if hasattr(self.scheduler, "add_noise"):
|
||||
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
else:
|
||||
latents = self.scheduelr.scale_noise(init_latents, timestep, noise)
|
||||
latents = self.scheduler.scale_noise(init_latents, timestep, noise)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
@@ -513,7 +508,7 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
||||
instead.
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
@@ -530,6 +525,8 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
strength (`float`, defaults to `0.8`):
|
||||
Higher strength leads to more differences between original image and generated video.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -559,8 +556,9 @@ class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The dtype to use for the torch.amp.autocast.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .modeling_paella_vq_model import PaellaVQModel
|
||||
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
|
||||
|
||||
@@ -56,7 +56,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
class WuerstchenDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating images from the Wuerstchen model.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import deprecate, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
|
||||
from .modeling_paella_vq_model import PaellaVQModel
|
||||
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
|
||||
from .modeling_wuerstchen_prior import WuerstchenPrior
|
||||
@@ -40,7 +40,7 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Combined Pipeline for text-to-image generation using Wuerstchen
|
||||
|
||||
@@ -68,6 +68,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
The scheduler to be used for prior pipeline.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
_load_connected_pipes = True
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -493,7 +493,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
||||
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
|
||||
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
|
||||
if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
|
||||
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
|
||||
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
|
||||
@@ -645,7 +645,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
|
||||
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
|
||||
|
||||
if cls._is_cuda_capability_atleast_8_9():
|
||||
if cls._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
|
||||
|
||||
return QUANTIZATION_TYPES
|
||||
@@ -655,14 +655,16 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_cuda_capability_atleast_8_9() -> bool:
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major == 8:
|
||||
return minor >= 9
|
||||
return major >= 9
|
||||
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
|
||||
if torch.cuda.is_available():
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major == 8:
|
||||
return minor >= 9
|
||||
return major >= 9
|
||||
elif torch.xpu.is_available():
|
||||
return True
|
||||
else:
|
||||
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
|
||||
|
||||
def get_apply_tensor_subclass(self):
|
||||
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
||||
|
||||
@@ -247,6 +247,14 @@ def _set_state_dict_into_text_encoder(
|
||||
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
|
||||
|
||||
|
||||
def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
|
||||
metadatas = {}
|
||||
for module_name, module in modules_to_save.items():
|
||||
if module is not None:
|
||||
metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
|
||||
return metadatas
|
||||
|
||||
|
||||
def compute_density_for_timestep_sampling(
|
||||
weighting_scheme: str,
|
||||
batch_size: int,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,300 @@
|
||||
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Doc utilities: Utilities related to documentation
|
||||
|
||||
Adapted from:
|
||||
https://github.com/huggingface/transformers/blob/5a95ed5ca0826c867e35e52f698db4d8fc907bcb/src/transformers/utils/doc.py
|
||||
"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
import textwrap
|
||||
import types
|
||||
from collections import OrderedDict
|
||||
|
||||
from ..pipelines.auto_pipeline import AUTO_TEXT2IMAGE_PIPELINES_MAPPING
|
||||
|
||||
|
||||
def get_docstring_indentation_level(func):
|
||||
"""Return the indentation level of the start of the docstring of a class or function (or method)."""
|
||||
# We assume classes are always defined in the global scope
|
||||
if inspect.isclass(func):
|
||||
return 4
|
||||
source = inspect.getsource(func)
|
||||
first_line = source.splitlines()[0]
|
||||
function_def_level = len(first_line) - len(first_line.lstrip())
|
||||
return 4 + function_def_level
|
||||
|
||||
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
def add_start_docstrings_to_model_forward(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
|
||||
intro = rf""" The {class_name} forward method, overrides the `__call__` special method.
|
||||
|
||||
<Tip>
|
||||
|
||||
Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
|
||||
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
|
||||
the latter silently ignores them.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
correct_indentation = get_docstring_indentation_level(fn)
|
||||
current_doc = fn.__doc__ if fn.__doc__ is not None else ""
|
||||
try:
|
||||
first_non_empty = next(line for line in current_doc.splitlines() if line.strip() != "")
|
||||
doc_indentation = len(first_non_empty) - len(first_non_empty.lstrip())
|
||||
except StopIteration:
|
||||
doc_indentation = correct_indentation
|
||||
|
||||
docs = docstr
|
||||
# In this case, the correct indentation level (class method, 2 Python levels) was respected, and we should
|
||||
# correctly reindent everything. Otherwise, the doc uses a single indentation level
|
||||
if doc_indentation == 4 + correct_indentation:
|
||||
docs = [textwrap.indent(textwrap.dedent(doc), " " * correct_indentation) for doc in docstr]
|
||||
intro = textwrap.indent(textwrap.dedent(intro), " " * correct_indentation)
|
||||
|
||||
docstring = "".join(docs) + current_doc
|
||||
fn.__doc__ = intro + docstring
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
def add_end_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
PT_RETURN_INTRODUCTION = r"""
|
||||
Returns:
|
||||
[`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
|
||||
`torch.FloatTensor` (if `return_dict=False` is passed) comprising various
|
||||
elements depending on the model and inputs.
|
||||
|
||||
"""
|
||||
|
||||
TEXT_TO_IMAGE_PIPELINE_CLASSES = list({p[0] for p in AUTO_TEXT2IMAGE_PIPELINES_MAPPING})
|
||||
|
||||
def _get_indent(t):
|
||||
"""Returns the indentation in the first line of t"""
|
||||
search = re.search(r"^(\s*)\S", t)
|
||||
return "" if search is None else search.groups()[0]
|
||||
|
||||
|
||||
def _convert_output_args_doc(output_args_doc):
|
||||
"""Convert output_args_doc to display properly."""
|
||||
# Split output_arg_doc in blocks argument/description
|
||||
indent = _get_indent(output_args_doc)
|
||||
blocks = []
|
||||
current_block = ""
|
||||
for line in output_args_doc.split("\n"):
|
||||
# If the indent is the same as the beginning, the line is the name of new arg.
|
||||
if _get_indent(line) == indent:
|
||||
if len(current_block) > 0:
|
||||
blocks.append(current_block[:-1])
|
||||
current_block = f"{line}\n"
|
||||
else:
|
||||
# Otherwise it's part of the description of the current arg.
|
||||
# We need to remove 2 spaces to the indentation.
|
||||
current_block += f"{line[2:]}\n"
|
||||
blocks.append(current_block[:-1])
|
||||
|
||||
# Format each block for proper rendering
|
||||
for i in range(len(blocks)):
|
||||
blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
|
||||
blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
|
||||
|
||||
return "\n".join(blocks)
|
||||
|
||||
|
||||
def _prepare_output_docstrings(output_type, config_class, min_indent=None, add_intro=True):
|
||||
"""
|
||||
Prepares the return part of the docstring using `output_type`.
|
||||
"""
|
||||
output_docstring = output_type.__doc__
|
||||
params_docstring = None
|
||||
if output_docstring is not None:
|
||||
# Remove the head of the docstring to keep the list of args only
|
||||
lines = output_docstring.split("\n")
|
||||
i = 0
|
||||
while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
|
||||
i += 1
|
||||
if i < len(lines):
|
||||
params_docstring = "\n".join(lines[(i + 1) :])
|
||||
params_docstring = _convert_output_args_doc(params_docstring)
|
||||
elif add_intro:
|
||||
raise ValueError(
|
||||
f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. Make sure it has "
|
||||
"docstring and contain either `Args` or `Parameters`."
|
||||
)
|
||||
|
||||
# Add the return introduction
|
||||
if add_intro:
|
||||
full_output_type = f"{output_type.__module__}.{output_type.__name__}"
|
||||
intro = PT_RETURN_INTRODUCTION
|
||||
intro = intro.format(full_output_type=full_output_type, config_class=config_class)
|
||||
else:
|
||||
full_output_type = str(output_type)
|
||||
intro = f"\nReturns:\n `{full_output_type}`"
|
||||
if params_docstring is not None:
|
||||
intro += ":\n"
|
||||
|
||||
result = intro
|
||||
if params_docstring is not None:
|
||||
result += params_docstring
|
||||
|
||||
# Apply minimum indent if necessary
|
||||
if min_indent is not None:
|
||||
lines = result.split("\n")
|
||||
# Find the indent of the first nonempty line
|
||||
i = 0
|
||||
while len(lines[i]) == 0:
|
||||
i += 1
|
||||
indent = len(_get_indent(lines[i]))
|
||||
# If too small, add indentation to all nonempty lines
|
||||
if indent < min_indent:
|
||||
to_add = " " * (min_indent - indent)
|
||||
lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
|
||||
result = "\n".join(lines)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
FAKE_MODEL_DISCLAIMER = """
|
||||
<Tip warning={true}>
|
||||
|
||||
This example uses a random model as the real ones are all very big. To get proper results, you should use
|
||||
{real_checkpoint} instead of {fake_checkpoint}. If you get out-of-memory when loading that checkpoint, you can
|
||||
refer to our optimization docs.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
|
||||
|
||||
PT_TEXT_TO_IMAGE_SAMPLE = r"""
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
>>> import torch
|
||||
|
||||
>>> # If memory doesn't allow, enable optimizations like `enable_model_cpu_offload()`.
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("{checkpoint}", torch_dtype=torch.bfloat16).to("cuda")
|
||||
|
||||
>>> prompt = "a photo of a cute dog."
|
||||
>>> image = pipe(prompt).images[0] # Configure other pipe call arguments as needed.
|
||||
```
|
||||
"""
|
||||
|
||||
PT_SAMPLE_DOCSTRINGS = {
|
||||
"Text2Image": PT_TEXT_TO_IMAGE_SAMPLE
|
||||
}
|
||||
PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS = OrderedDict(["text-to-image", PT_TEXT_TO_IMAGE_SAMPLE])
|
||||
|
||||
def filter_outputs_from_example(docstring, **kwargs):
|
||||
"""
|
||||
Removes the lines testing an output with the doctest syntax in a code sample when it's set to `None`.
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
continue
|
||||
|
||||
doc_key = "{" + key + "}"
|
||||
docstring = re.sub(rf"\n([^\n]+)\n\s+{doc_key}\n", "\n", docstring)
|
||||
|
||||
return docstring
|
||||
|
||||
|
||||
def add_code_sample_docstrings(
|
||||
*docstr,
|
||||
checkpoint=None,
|
||||
output_type=None,
|
||||
config_class=None,
|
||||
model_cls=None,
|
||||
):
|
||||
def docstring_decorator(fn):
|
||||
# model_class defaults to function's class if not specified otherwise
|
||||
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
|
||||
|
||||
sample_docstrings = PT_SAMPLE_DOCSTRINGS
|
||||
|
||||
# putting all kwargs for docstrings in a dict to be used
|
||||
# with the `.format(**doc_kwargs)`. Note that string might
|
||||
# be formatted with non-existing keys, which is fine.
|
||||
doc_kwargs = {
|
||||
"checkpoint": checkpoint,
|
||||
"true": "{true}", # For <Tip warning={true}> syntax that conflicts with formatting.
|
||||
}
|
||||
|
||||
if model_class in TEXT_TO_IMAGE_PIPELINE_CLASSES:
|
||||
code_sample = sample_docstrings["Text2Image"]
|
||||
else:
|
||||
raise ValueError(f"Docstring can't be built for model {model_class}")
|
||||
|
||||
code_sample = filter_outputs_from_example(code_sample)
|
||||
func_doc = (fn.__doc__ or "") + "".join(docstr)
|
||||
output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
|
||||
built_doc = code_sample.format(**doc_kwargs)
|
||||
|
||||
fn.__doc__ = func_doc + output_doc + built_doc
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
def replace_return_docstrings(output_type=None, config_class=None):
|
||||
def docstring_decorator(fn):
|
||||
func_doc = fn.__doc__
|
||||
lines = func_doc.split("\n")
|
||||
i = 0
|
||||
while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
|
||||
i += 1
|
||||
if i < len(lines):
|
||||
indent = len(_get_indent(lines[i]))
|
||||
lines[i] = _prepare_output_docstrings(output_type, config_class, min_indent=indent)
|
||||
func_doc = "\n".join(lines)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, "
|
||||
f"current docstring is:\n{func_doc}"
|
||||
)
|
||||
fn.__doc__ = func_doc
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
def copy_func(f):
|
||||
"""Returns a copy of a function f."""
|
||||
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
|
||||
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
|
||||
g = functools.update_wrapper(g, f)
|
||||
g.__kwdefaults__ = f.__kwdefaults__
|
||||
return g
|
||||
@@ -325,6 +325,21 @@ class CacheMixin(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ChromaTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1150,6 +1165,21 @@ class WanTransformer3DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class WanVACETransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
def get_constant_schedule(*args, **kwargs):
|
||||
requires_backends(get_constant_schedule, ["torch"])
|
||||
|
||||
|
||||
@@ -272,6 +272,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ChromaPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CLIPImageProjection(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -407,6 +422,36 @@ class ConsisIDPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Cosmos2VideoToWorldPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CosmosTextToWorldPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -2897,6 +2942,21 @@ class WanPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanVACEPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanVideoToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -154,12 +154,30 @@ def check_imports(filename):
|
||||
return get_relative_imports(filename)
|
||||
|
||||
|
||||
def get_class_in_module(class_name, module_path):
|
||||
def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None):
|
||||
"""
|
||||
Import a module on the cache directory for modules and extract a class from it.
|
||||
"""
|
||||
module_path = module_path.replace(os.path.sep, ".")
|
||||
module = importlib.import_module(module_path)
|
||||
try:
|
||||
module = importlib.import_module(module_path)
|
||||
except ModuleNotFoundError as e:
|
||||
# This can happen when the repo id contains ".", which Python's import machinery interprets as a directory
|
||||
# separator. We do a bit of monkey patching to detect and fix this case.
|
||||
if not (
|
||||
pretrained_model_name_or_path is not None
|
||||
and "." in pretrained_model_name_or_path
|
||||
and module_path.startswith("diffusers_modules")
|
||||
and pretrained_model_name_or_path.replace("/", "--") in module_path
|
||||
):
|
||||
raise e # We can't figure this one out, just reraise the original error
|
||||
|
||||
corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
|
||||
corrected_path = corrected_path.replace(
|
||||
pretrained_model_name_or_path.replace("/", "--").replace(".", "/"),
|
||||
pretrained_model_name_or_path.replace("/", "--"),
|
||||
)
|
||||
module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()
|
||||
|
||||
if class_name is None:
|
||||
return find_pipeline_class(module)
|
||||
@@ -454,4 +472,4 @@ def get_class_from_dynamic_module(
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
||||
return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path)
|
||||
|
||||
@@ -99,6 +99,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TORCH is set")
|
||||
_torch_available = False
|
||||
_torch_version = "N/A"
|
||||
|
||||
_jax_version = "N/A"
|
||||
_flax_version = "N/A"
|
||||
|
||||
@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily
|
||||
"""
|
||||
|
||||
import enum
|
||||
import json
|
||||
|
||||
from .import_utils import is_torch_available
|
||||
from .logging import get_logger
|
||||
@@ -347,3 +348,19 @@ def state_dict_all_zero(state_dict, filter_str=None):
|
||||
state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
|
||||
|
||||
return all(torch.all(param == 0).item() for param in state_dict.values())
|
||||
|
||||
|
||||
def _load_sft_state_dict_metadata(model_file: str):
|
||||
import safetensors.torch
|
||||
|
||||
from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
if metadata:
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
return json.loads(raw) if raw else None
|
||||
else:
|
||||
return None
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user