Compare commits

..

7 Commits

Author SHA1 Message Date
sayakpaul b7f414f719 more combos. 2025-06-18 07:10:34 +05:30
Sayak Paul 222ed1500d Merge branch 'main' into tip-compile-offload 2025-06-17 16:32:17 +05:30
sayakpaul edef2da4e4 add a tip for compile + offload 2025-06-17 16:30:24 +05:30
Linoy Tsaban 1bc6f3dc0f [LoRA training] update metadata use for lora alpha + README (#11723)
* lora alpha

* Apply style fixes

* Update examples/advanced_diffusion_training/README_flux.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix readme format

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-06-17 12:19:27 +03:00
Aryan 79bd7ecc78 Support more Wan loras (VACE) (#11726)
update
2025-06-17 10:39:18 +05:30
David Berenstein 9b834f8710 Add Pruna optimization framework documentation (#11688)
* Add Pruna optimization framework documentation

- Introduced a new section for Pruna in the table of contents.
- Added comprehensive documentation for Pruna, detailing its optimization techniques, installation instructions, and examples for optimizing and evaluating models

* Enhance Pruna documentation with image alt text and code block formatting

- Added alt text to images for better accessibility and context.
- Changed code block syntax from diff to python for improved clarity.

* Add installation section to Pruna documentation

- Introduced a new installation section in the Pruna documentation to guide users on how to install the framework.
- Enhanced the overall clarity and usability of the documentation for new users.

* Update pruna.md

* Update pruna.md

* Update Pruna documentation for model optimization and evaluation

- Changed section titles for consistency and clarity, from "Optimizing models" to "Optimize models" and "Evaluating and benchmarking optimized models" to "Evaluate and benchmark models".
- Enhanced descriptions to clarify the use of `diffusers` models and the evaluation process.
- Added a new example for evaluating standalone `diffusers` models.
- Updated references and links for better navigation within the documentation.

* Refactor Pruna documentation for clarity and consistency

- Removed outdated references to FLUX-juiced and streamlined the explanation of benchmarking.
- Enhanced the description of evaluating standalone `diffusers` models.
- Cleaned up code examples by removing unnecessary imports and comments for better readability.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Enhance Pruna documentation with new examples and clarifications

- Added an image to illustrate the optimization process.
- Updated the explanation for sharing and loading optimized models on the Hugging Face Hub.
- Clarified the evaluation process for optimized models using the EvaluationAgent.
- Improved descriptions for defining metrics and evaluating standalone diffusers models.

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-06-16 12:25:05 -07:00
Carl Thomé 81426b0f19 Fix misleading comment (#11722) 2025-06-16 08:47:00 -10:00
18 changed files with 381 additions and 1396 deletions
+2
View File
@@ -180,6 +180,8 @@
title: Caching
- local: optimization/memory
title: Reduce memory usage
- local: optimization/pruna
title: Pruna
- local: optimization/xformers
title: xFormers
- local: optimization/tome
+13
View File
@@ -302,6 +302,13 @@ compute-bound, [group-offloading](#group-offloading) tends to be better. Group o
</Tip>
<Tip>
When using offloading, users can additionally compile the diffusion transformer/unet to get a
good speed-memory trade-off. First set `torch._dynamo.config.cache_size_limit=1000`, and then before calling the pipeline, add `pipeline.transformer.compile()`.
</Tip>
## Layerwise casting
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
@@ -365,6 +372,12 @@ apply_layerwise_casting(
)
```
<Tip>
Layerwise casting can be combined with group offloading.
</Tip>
## torch.channels_last
[torch.channels_last](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) flips how tensors are stored from `(batch size, channels, height, width)` to `(batch size, heigh, width, channels)`. This aligns the tensors with how the hardware sequentially accesses the tensors stored in memory and avoids skipping around in memory to access the pixel values.
+187
View File
@@ -0,0 +1,187 @@
# Pruna
[Pruna](https://github.com/PrunaAI/pruna) is a model optimization framework that offers various optimization methods - quantization, pruning, caching, compilation - for accelerating inference and reducing memory usage. A general overview of the optimization methods are shown below.
| Technique | Description | Speed | Memory | Quality |
|--------------|-----------------------------------------------------------------------------------------------|:-----:|:------:|:-------:|
| `batcher` | Groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing processing time. | ✅ | ❌ | |
| `cacher` | Stores intermediate results of computations to speed up subsequent operations. | ✅ | | |
| `compiler` | Optimises the model with instructions for specific hardware. | ✅ | | |
| `distiller` | Trains a smaller, simpler model to mimic a larger, more complex model. | ✅ | ✅ | ❌ |
| `quantizer` | Reduces the precision of weights and activations, lowering memory requirements. | ✅ | ✅ | ❌ |
| `pruner` | Removes less important or redundant connections and neurons, resulting in a sparser, more efficient network. | ✅ | ✅ | ❌ |
| `recoverer` | Restores the performance of a model after compression. | ➖ | | ✅ |
| `factorizer` | Factorization batches several small matrix multiplications into one large fused operation. | ✅ | | |
| `enhancer` | Enhances the model output by applying post-processing algorithms such as denoising or upscaling. | ❌ | - | ✅ |
✅ (improves), (approx. the same), ❌ (worsens)
Explore the full range of optimization methods in the [Pruna documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms).
## Installation
Install Pruna with the following command.
```bash
pip install pruna
```
## Optimize Diffusers models
A broad range of optimization algorithms are supported for Diffusers models as shown below.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/diffusers_combinations.png" alt="Overview of the supported optimization algorithms for diffusers models">
</div>
The example below optimizes [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
with a combination of factorizer, compiler, and cacher algorithms. This combination accelerates inference by up to 4.2x and cuts peak GPU memory usage from 34.7GB to 28.0GB, all while maintaining virtually the same output quality.
> [!TIP]
> Refer to the [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) docs to learn more about the optimization techniques used in this example.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_combination.png" alt="Optimization techniques used for FLUX.1-dev showing the combination of factorizer, compiler, and cacher algorithms">
</div>
Start by defining a `SmashConfig` with the optimization algorithms to use. To optimize the model, wrap the pipeline and the `SmashConfig` with `smash` and then use the pipeline as normal for inference.
```python
import torch
from diffusers import FluxPipeline
from pruna import PrunaModel, SmashConfig, smash
# load the model
# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell with a small GPU memory
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
).to("cuda")
# define the configuration
smash_config = SmashConfig()
smash_config["factorizer"] = "qkv_diffusers"
smash_config["compiler"] = "torch_compile"
smash_config["torch_compile_target"] = "module_list"
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 2
# for the best results in terms of speed you can add these configs
# however they will increase your warmup time from 1.5 min to 10 min
# smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
# smash_config["quantizer"] = "torchao"
# smash_config["torchao_quant_type"] = "fp8dq"
# smash_config["torchao_excluded_modules"] = "norm+embedding"
# optimize the model
smashed_pipe = smash(pipe, smash_config)
# run the model
smashed_pipe("a knitted purple prune").images[0]
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_smashed_comparison.png">
</div>
After optimization, we can share and load the optimized model using the Hugging Face Hub.
```python
# save the model
smashed_pipe.save_to_hub("<username>/FLUX.1-dev-smashed")
# load the model
smashed_pipe = PrunaModel.from_hub("<username>/FLUX.1-dev-smashed")
```
## Evaluate and benchmark Diffusers models
Pruna provides the [EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) to evaluate the quality of your optimized models.
We can metrics we care about, such as total time and throughput, and the dataset to evaluate on. We can define a model and pass it to the `EvaluationAgent`.
<hfoptions id="eval">
<hfoption id="optimized model">
We can load and evaluate an optimized model by using the `EvaluationAgent` and pass it to the `Task`.
```python
import torch
from diffusers import FluxPipeline
from pruna import PrunaModel
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import (
ThroughputMetric,
TorchMetricWrapper,
TotalTimeMetric,
)
from pruna.evaluation.task import Task
# define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# load the model
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed")
# Define the metrics
metrics = [
TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),
ThroughputMetric(n_iterations=20, n_warmup_iterations=5),
TorchMetricWrapper("clip"),
]
# Define the datamodule
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(10)
# Define the task and evaluation agent
task = Task(metrics, datamodule=datamodule, device=device)
eval_agent = EvaluationAgent(task)
# Evaluate smashed model and offload it to CPU
smashed_pipe.move_to_device(device)
smashed_pipe_results = eval_agent.evaluate(smashed_pipe)
smashed_pipe.move_to_device("cpu")
```
</hfoption>
<hfoption id="standalone model">
Instead of comparing the optimized model to the base model, you can also evaluate the standalone `diffusers` model. This is useful if you want to evaluate the performance of the model without the optimization. We can do so by using the `PrunaModel` wrapper and run the `EvaluationAgent` on it.
```python
import torch
from diffusers import FluxPipeline
from pruna import PrunaModel
# load the model
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
).to("cpu")
wrapped_pipe = PrunaModel(model=pipe)
```
</hfoption>
</hfoptions>
Now that you have seen how to optimize and evaluate your models, you can start using Pruna to optimize your own models. Luckily, we have many examples to help you get started.
> [!TIP]
> For more details about benchmarking Flux, check out the [Announcing FLUX-Juiced: The Fastest Image Generation Endpoint (2.6 times faster)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) blog post and the [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) Space.
## Reference
- [Pruna](https://github.com/pruna-ai/pruna)
- [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)
- [Pruna evaluation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)
- [Pruna tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)
@@ -76,6 +76,24 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- lora_alpha vs. rank:
This ratio dictates the LoRA's effective strength:
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
@@ -20,6 +21,8 @@ import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
@@ -281,3 +284,45 @@ class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
@@ -55,6 +55,7 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_set_state_dict_into_text_encoder,
cast_training_params,
compute_density_for_timestep_sampling,
@@ -431,6 +432,13 @@ def parse_args(input_args=None):
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
@@ -1556,7 +1564,7 @@ def main(args):
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
@@ -1565,7 +1573,7 @@ def main(args):
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1582,13 +1590,15 @@ def main(args):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["text_encoder"] = model
elif isinstance(model, type(unwrap_model(text_encoder_two))):
pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
else:
@@ -1601,6 +1611,7 @@ def main(args):
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
@@ -2359,16 +2370,19 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
if args.train_text_encoder:
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
modules_to_save["text_encoder"] = text_encoder_one
else:
text_encoder_lora_layers = None
@@ -2377,6 +2391,7 @@ def main(args):
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
**_collate_lora_metadata(modules_to_save),
)
if args.train_text_encoder_ti:
+17
View File
@@ -170,6 +170,23 @@ accelerate launch train_dreambooth_lora_flux.py \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- lora_alpha vs. rank:
This ratio dictates the LoRA's effective strength:
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
+1 -1
View File
@@ -1,4 +1,4 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
import argparse
import json
import os
-2
View File
@@ -353,7 +353,6 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
@@ -946,7 +945,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
+51 -36
View File
@@ -1596,7 +1596,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
converted_state_dict = {}
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
min_block = min(block_numbers)
max_block = max(block_numbers)
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
@@ -1622,45 +1625,57 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
# For the `diff_b` keys, we treat them as lora_bias.
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
for i in range(num_blocks):
for i in range(min_block, max_block + 1):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
)
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
)
if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.diff_b"
)
original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.self_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
)
if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.diff_b"
)
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
)
if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.diff_b"
)
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
# FFN
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
@@ -1674,10 +1689,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
if f"blocks.{i}.{o}.diff_b" in original_state_dict:
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
f"blocks.{i}.{o}.diff_b"
)
original_key = f"blocks.{i}.{o}.diff_b"
converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
# Remaining.
if original_state_dict:
+2 -6
View File
@@ -2543,9 +2543,7 @@ class FusedFluxAttnProcessor2_0:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -2778,9 +2776,7 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -250,21 +250,15 @@ class ChromaSingleTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**joint_attention_kwargs,
)
@@ -318,7 +312,6 @@ class ChromaTransformerBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -328,15 +321,11 @@ class ChromaTransformerBlock(nn.Module):
encoder_hidden_states, emb=temb_txt
)
joint_attention_kwargs = joint_attention_kwargs or {}
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**joint_attention_kwargs,
)
@@ -581,7 +570,6 @@ class ChromaTransformer2DModel(
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
@@ -671,7 +659,11 @@ class ChromaTransformer2DModel(
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
@@ -680,7 +672,6 @@ class ChromaTransformer2DModel(
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
joint_attention_kwargs=joint_attention_kwargs,
)
@@ -713,7 +704,6 @@ class ChromaTransformer2DModel(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
joint_attention_kwargs=joint_attention_kwargs,
)
+2 -2
View File
@@ -148,7 +148,7 @@ else:
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
_import_structure["chroma"] = ["ChromaPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
@@ -537,7 +537,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .chroma import ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
_import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -32,7 +31,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_chroma import ChromaPipeline
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
else:
import sys
+20 -105
View File
@@ -1,4 +1,4 @@
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -52,21 +52,12 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaPipeline
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-schnell",
... transformer=transformer,
... text_encoder=text_encoder,
... tokenizer=tokenizer,
... torch_dtype=torch.bfloat16,
>>> pipe = ChromaPipeline.from_single_file(
... "chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16
... )
>>> pipe.enable_model_cpu_offload()
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
>>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
>>> image.save("chroma.png")
```
"""
@@ -244,7 +235,6 @@ class ChromaPipeline(
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
attention_mask = attention_mask.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -252,10 +242,7 @@ class ChromaPipeline(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
attention_mask = attention_mask.repeat(1, num_images_per_prompt)
attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
return prompt_embeds, attention_mask
return prompt_embeds
def encode_prompt(
self,
@@ -263,10 +250,8 @@ class ChromaPipeline(
negative_prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
do_classifier_free_guidance: bool = True,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
@@ -283,7 +268,7 @@ class ChromaPipeline(
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.Tensor`, *optional*):
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
@@ -308,7 +293,7 @@ class ChromaPipeline(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
@@ -338,13 +323,12 @@ class ChromaPipeline(
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
if self.text_encoder is not None:
@@ -352,14 +336,7 @@ class ChromaPipeline(
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return (
prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_embeds,
negative_text_ids,
negative_prompt_attention_mask,
)
return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
@@ -417,9 +394,7 @@ class ChromaPipeline(
width,
negative_prompt=None,
prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_embeds=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -453,14 +428,6 @@ class ChromaPipeline(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError(
"Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -567,25 +534,6 @@ class ChromaPipeline(
return latents, latent_image_ids
def _prepare_attention_mask(
self,
batch_size,
sequence_length,
dtype,
attention_mask=None,
):
if attention_mask is None:
return attention_mask
# Extend the prompt attention mask to account for image tokens in the final sequence
attention_mask = torch.cat(
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
dim=1,
)
attention_mask = attention_mask.to(dtype)
return attention_mask
@property
def guidance_scale(self):
return self._guidance_scale
@@ -618,20 +566,18 @@ class ChromaPipeline(
negative_prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 35,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -672,11 +618,11 @@ class ChromaPipeline(
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -690,18 +636,10 @@ class ChromaPipeline(
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (torch.Tensor, *optional*):
Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
Chroma requires a single padding token remain unmasked. Please refer to
https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
negative_prompt_attention_mask (torch.Tensor, *optional*):
Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to
https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -740,9 +678,7 @@ class ChromaPipeline(
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -768,17 +704,13 @@ class ChromaPipeline(
(
prompt_embeds,
text_ids,
prompt_attention_mask,
negative_prompt_embeds,
negative_text_ids,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
do_classifier_free_guidance=self.do_classifier_free_guidance,
device=device,
num_images_per_prompt=num_images_per_prompt,
@@ -798,7 +730,6 @@ class ChromaPipeline(
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
@@ -809,20 +740,6 @@ class ChromaPipeline(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
attention_mask = self._prepare_attention_mask(
batch_size=latents.shape[0],
sequence_length=image_seq_len,
dtype=latents.dtype,
attention_mask=prompt_attention_mask,
)
negative_attention_mask = self._prepare_attention_mask(
batch_size=latents.shape[0],
sequence_length=image_seq_len,
dtype=latents.dtype,
attention_mask=negative_prompt_attention_mask,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
@@ -884,7 +801,6 @@ class ChromaPipeline(
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
attention_mask=attention_mask,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
@@ -898,7 +814,6 @@ class ChromaPipeline(
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
attention_mask=negative_attention_mask,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
File diff suppressed because it is too large Load Diff
@@ -272,21 +272,6 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ChromaImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ChromaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1,170 +0,0 @@
import random
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, torch_device
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class ChromaImg2ImgPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
):
pipeline_class = ChromaImg2ImgPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
batch_params = frozenset(["prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = ChromaTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
axes_dims_rope=[4, 4, 8],
approximator_hidden_dim=32,
approximator_layers=1,
approximator_num_channels=16,
)
torch.manual_seed(0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"strength": 0.8,
"output_type": "np",
}
return inputs
def test_chroma_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(pipe.transformer), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
)
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
"Fusion of QKV projections shouldn't affect the outputs."
)
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
"Original outputs should match when fused QKV projections are disabled."
)
def test_chroma_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)