Compare commits

..

117 Commits

Author SHA1 Message Date
DN6 d267bb6955 update 2025-06-14 01:20:39 +05:30
DN6 e10f701537 update 2025-06-14 00:51:24 +05:30
DN6 0497faa3db update 2025-06-14 00:33:39 +05:30
DN6 4f00bae5de update 2025-06-14 00:31:33 +05:30
DN6 a967e66d03 update 2025-06-14 00:28:56 +05:30
DN6 2b559e9b79 Merge branch 'chroma-fork' into chroma-final 2025-06-14 00:27:54 +05:30
DN6 589e939e33 Revert "fix equal size list input"
This reverts commit 3fe4ad67d5.
2025-06-14 00:17:17 +05:30
BuildTools c711e8f10b fix equal size list input 2025-06-14 00:17:17 +05:30
BuildTools 0978b609c8 fix tests 2025-06-14 00:17:17 +05:30
BuildTools 4e24f26d6f default proj dim 2025-06-14 00:17:17 +05:30
BuildTools 8694f2ce53 add encoder test, remove pooled dim 2025-06-14 00:17:17 +05:30
BuildTools fd3e94450a push local changes, fix docs 2025-06-14 00:17:17 +05:30
Dhruv Nair 41751a3ec0 update 2025-06-13 20:41:49 +02:00
BuildTools 3fe4ad67d5 fix equal size list input 2025-06-13 10:51:31 -06:00
BuildTools 49a4c8bc22 fix tests 2025-06-13 09:41:44 -06:00
BuildTools 06fb9957a7 default proj dim 2025-06-13 08:38:02 -06:00
BuildTools 16b6e33916 add encoder test, remove pooled dim 2025-06-13 08:11:12 -06:00
BuildTools 178c4ec928 push local changes, fix docs 2025-06-13 07:46:29 -06:00
DN6 292469d755 update 2025-06-13 18:43:26 +05:30
DN6 bf56c953b8 Merge branch 'chroma-fork' into chroma-final 2025-06-13 18:41:56 +05:30
BuildTools b85229e262 Fix all pipeline test 2025-06-13 07:07:05 -06:00
DN6 f1be3ebc98 Merge branch 'chroma-fork' into chroma-final 2025-06-13 18:30:13 +05:30
Dhruv Nair 6735507705 fix for tests 2025-06-13 14:51:42 +02:00
BuildTools de9a07fc20 fix test skipping again 2025-06-13 05:47:41 -06:00
BuildTools 2b6722ecea fix test skipping 2025-06-13 05:35:58 -06:00
BuildTools 00ebba9725 skip batch tests 2025-06-13 05:25:11 -06:00
BuildTools bea8b0d86e make style, make quality 2025-06-13 04:54:33 -06:00
BuildTools 28dea06b3d fix docs 2025-06-13 04:53:30 -06:00
BuildTools 60e41b7835 Merge remote-tracking branch 'origin/chroma' into chroma 2025-06-13 04:43:48 -06:00
BuildTools 876649336e Make most transformer tests work 2025-06-13 04:43:31 -06:00
Edna 272685c0e5 Merge branch 'main' into chroma 2025-06-13 04:38:42 -06:00
BuildTools 829c6f199e Make more pipeline tests work 2025-06-13 04:38:13 -06:00
Dhruv Nair 89faa71f04 fix batch inference 2025-06-13 12:31:11 +02:00
DN6 926dcc6319 update to pad tokens 2025-06-13 13:43:17 +05:30
DN6 74fe45e823 update chroma transformer approximator init params 2025-06-13 13:36:39 +05:30
DN6 35dc65b7da update chroma transformer params 2025-06-13 13:30:04 +05:30
DN6 f35ec17a83 Merge remote-tracking branch '11698/chroma' into chroma-final 2025-06-13 11:28:57 +05:30
BuildTools 381e64b966 revert style fix 2025-06-12 22:22:39 -06:00
BuildTools c330f08fa2 make fix-copes 2025-06-12 21:53:55 -06:00
BuildTools 523150fb2c fix import 2025-06-12 21:47:35 -06:00
BuildTools 2bc51c8387 try to fix import 2025-06-12 21:36:09 -06:00
BuildTools fd36924620 remove # Copied from on protected members 2025-06-12 21:20:32 -06:00
Edna e97a4dd0c7 fix # Copied from error 2025-06-12 21:13:12 -06:00
Edna ad01d636be Merge branch 'main' into chroma 2025-06-12 21:06:55 -06:00
BuildTools 68b9cce897 switch to new input ids 2025-06-12 21:06:43 -06:00
github-actions[bot] f49b149c1c Apply style fixes 2025-06-13 02:02:25 +00:00
DN6 19733af2fc make style 2025-06-13 07:22:45 +05:30
BuildTools c85e46bd42 Fix auto pipeline + make style, quality 2025-06-12 10:31:02 -06:00
BuildTools d31cf81566 Move Approximator and Embeddings 2025-06-12 10:20:27 -06:00
Edna 2347d53f90 Update src/diffusers/models/transformers/transformer_chroma.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-06-12 10:12:27 -06:00
Edna cfd5b34051 fix chroma pipeline fast tests 2025-06-12 03:49:39 -06:00
Edna c8d6aef936 chroma init 2025-06-12 03:47:24 -06:00
Edna f8d4a1a774 move chroma test (oops) 2025-06-12 03:46:28 -06:00
Edna 15ca813e3e Add transformer tests 2025-06-12 03:45:43 -06:00
Edna 7235805e75 Revert cond + uncond batching 2025-06-12 03:40:52 -06:00
Edna abf8a33a96 update norm imports 2025-06-12 03:33:23 -06:00
Edna 6a0db55af8 Update # Copied from statements 2025-06-12 03:27:35 -06:00
Edna fe5af79a19 Add # Copied from for shift 2025-06-12 03:23:09 -06:00
Edna bedb32087a (untested) batch cond and uncond 2025-06-12 03:18:33 -06:00
Edna 03fbd520f4 Add chroma fast tests 2025-06-12 03:11:48 -06:00
Edna 1442c9789a Remove pruned AdaLayerNorms 2025-06-12 03:05:10 -06:00
Edna a1fac68a2d Move chroma layers into transformer 2025-06-12 03:04:41 -06:00
Edna 3e36a21c8e Update chroma.md 2025-06-12 02:58:21 -06:00
Edna a93e64d6fb Update docs/source/en/api/models/chroma_transformer.md
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-06-12 02:57:28 -06:00
Edna 3f39b1a730 do_cfg -> self.do_classifier_free_guidance 2025-06-12 02:56:24 -06:00
Edna 18327cb57c Update docs/source/en/api/pipelines/chroma.md
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-06-12 02:52:39 -06:00
Edna da846d1fff fix hf papers regression in more places 2025-06-12 00:53:40 -06:00
Edna 42c0e8ecbe undo arxiv change
unsure why that happened
2025-06-12 00:50:36 -06:00
Edna 0c5eb44701 undo don't change dtype 2025-06-12 00:46:41 -06:00
Edna b0cf6803a7 initial chroma docs 2025-06-11 22:07:21 -06:00
Edna f821f2ad5e add .md (oops) 2025-06-11 21:54:43 -06:00
Edna 619921ca22 add chroma autodoc 2025-06-11 21:53:27 -06:00
BuildTools 1efa772f69 remove unused stuff, fix up docs 2025-06-11 21:46:40 -06:00
Edna 3e2452ded0 dont change dtype 2025-06-11 21:23:35 -06:00
Edna 2d57f3dbac Merge branch 'main' into chroma 2025-06-11 21:20:24 -06:00
Edna 1bd8fdfcb6 don't return length 2025-06-11 20:56:27 -06:00
Edna 406ab3b1e9 remove guidance from embeddings 2025-06-11 20:47:59 -06:00
Edna e31c94866d remove guidance embed (pipeline) 2025-06-11 20:46:59 -06:00
Edna 01bc0dcc56 remove guidance 2025-06-11 20:45:45 -06:00
Edna e69d73099d use DN6 embeddings 2025-06-11 20:05:28 -06:00
Edna 442f77a2d7 use chroma pipeline output 2025-06-11 19:59:43 -06:00
Edna ab7942174a use dn6 attn mask + fix true_cfg_scale 2025-06-11 19:57:31 -06:00
Edna f6de1afc3f update 2025-06-11 19:54:27 -06:00
Edna f783f38883 ensure correct dtype for chroma embeddings 2025-06-11 19:52:43 -06:00
Edna a3b6697bc3 Merge branch 'main' into chroma 2025-06-11 19:48:02 -06:00
Edna 68f771bf43 take pooled projections out of transformer 2025-06-11 19:38:38 -06:00
Edna df7fde7a6d fix load 2025-06-11 19:36:34 -06:00
Edna 77b429eda4 change to my own unpooled embeddeer 2025-06-11 19:35:10 -06:00
Edna 3309ffef1c remove pooled prompt embeds 2025-06-11 19:33:17 -06:00
Edna 146255aba1 no attn mask (can't get it to work) 2025-06-11 19:17:29 -06:00
Edna c9b46af65f wrap attn mask 2025-06-11 19:16:24 -06:00
Edna 7c75d8e98d dont modify mask (for now) 2025-06-11 19:15:18 -06:00
Edna 38429ffcac remove mask function 2025-06-11 19:11:47 -06:00
Edna f190c02af7 work on swapping text encoders 2025-06-11 19:09:37 -06:00
Edna 6c0aed14db remove prompt_2 2025-06-11 19:06:45 -06:00
Edna 0b027a2453 swap embedder location 2025-06-11 19:04:52 -06:00
Edna 2fcc75a6d8 take out variant from blocks 2025-06-11 18:55:56 -06:00
Edna af918c89dd change to chroma transformer 2025-06-11 18:55:03 -06:00
Edna 7445cf422a add chroma to pipeline init 2025-06-11 18:53:06 -06:00
Edna a6f231c7ce add chroma to auto pipeline 2025-06-11 18:51:45 -06:00
Edna 6441e70def update 2025-06-11 18:48:44 -06:00
Edna f0c75b6b6f update 2025-06-11 18:46:51 -06:00
Edna 5eb4b822ae fix single file 2025-06-11 18:38:58 -06:00
Edna 4e698b1088 add chroma to init 2025-06-11 18:21:10 -06:00
Edna c22930d7cc add chroma to init 2025-06-11 18:18:56 -06:00
Edna 7400278857 add chroma transformer to dummy tp 2025-06-11 18:16:44 -06:00
Edna 32659236b2 make chroma output class 2025-06-10 02:24:23 -06:00
Edna c8cbb31614 add chroma init 2025-06-10 02:22:52 -06:00
Edna b0df9691d2 get decently far in changing variant stuff 2025-06-10 02:09:52 -06:00
Edna 22ecd19f91 take out variant stuff 2025-06-09 21:32:52 -06:00
Edna 33ea0b65a4 add chroma to transformer init 2025-06-09 21:25:19 -06:00
Edna bc36a0d883 add chroma to mappings 2025-06-09 21:15:19 -06:00
Edna 32e6a006cf add chroma loader 2025-06-09 21:13:32 -06:00
Edna 15f2bd5c39 working state (embeddings) 2025-06-09 21:05:59 -06:00
Edna e271af9495 working state (normalization) 2025-06-09 21:03:10 -06:00
Edna 3c2865c534 working state form hameerabbasi and iddl (transformer) 2025-06-09 21:02:12 -06:00
Edna ff0b9a3c4c working state from hameerabbasi and iddl 2025-06-09 20:59:00 -06:00
34 changed files with 173 additions and 4423 deletions
-2
View File
@@ -180,8 +180,6 @@
title: Caching
- local: optimization/memory
title: Reduce memory usage
- local: optimization/pruna
title: Pruna
- local: optimization/xformers
title: xFormers
- local: optimization/tome
-16
View File
@@ -36,22 +36,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
## Cosmos2TextToImagePipeline
[[autodoc]] Cosmos2TextToImagePipeline
- all
- __call__
## Cosmos2VideoToWorldPipeline
[[autodoc]] Cosmos2VideoToWorldPipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
## CosmosImagePipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosImagePipelineOutput
-187
View File
@@ -1,187 +0,0 @@
# Pruna
[Pruna](https://github.com/PrunaAI/pruna) is a model optimization framework that offers various optimization methods - quantization, pruning, caching, compilation - for accelerating inference and reducing memory usage. A general overview of the optimization methods are shown below.
| Technique | Description | Speed | Memory | Quality |
|--------------|-----------------------------------------------------------------------------------------------|:-----:|:------:|:-------:|
| `batcher` | Groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing processing time. | ✅ | ❌ | |
| `cacher` | Stores intermediate results of computations to speed up subsequent operations. | ✅ | | |
| `compiler` | Optimises the model with instructions for specific hardware. | ✅ | | |
| `distiller` | Trains a smaller, simpler model to mimic a larger, more complex model. | ✅ | ✅ | ❌ |
| `quantizer` | Reduces the precision of weights and activations, lowering memory requirements. | ✅ | ✅ | ❌ |
| `pruner` | Removes less important or redundant connections and neurons, resulting in a sparser, more efficient network. | ✅ | ✅ | ❌ |
| `recoverer` | Restores the performance of a model after compression. | ➖ | | ✅ |
| `factorizer` | Factorization batches several small matrix multiplications into one large fused operation. | ✅ | | |
| `enhancer` | Enhances the model output by applying post-processing algorithms such as denoising or upscaling. | ❌ | - | ✅ |
✅ (improves), (approx. the same), ❌ (worsens)
Explore the full range of optimization methods in the [Pruna documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms).
## Installation
Install Pruna with the following command.
```bash
pip install pruna
```
## Optimize Diffusers models
A broad range of optimization algorithms are supported for Diffusers models as shown below.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/diffusers_combinations.png" alt="Overview of the supported optimization algorithms for diffusers models">
</div>
The example below optimizes [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
with a combination of factorizer, compiler, and cacher algorithms. This combination accelerates inference by up to 4.2x and cuts peak GPU memory usage from 34.7GB to 28.0GB, all while maintaining virtually the same output quality.
> [!TIP]
> Refer to the [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) docs to learn more about the optimization techniques used in this example.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_combination.png" alt="Optimization techniques used for FLUX.1-dev showing the combination of factorizer, compiler, and cacher algorithms">
</div>
Start by defining a `SmashConfig` with the optimization algorithms to use. To optimize the model, wrap the pipeline and the `SmashConfig` with `smash` and then use the pipeline as normal for inference.
```python
import torch
from diffusers import FluxPipeline
from pruna import PrunaModel, SmashConfig, smash
# load the model
# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell with a small GPU memory
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
).to("cuda")
# define the configuration
smash_config = SmashConfig()
smash_config["factorizer"] = "qkv_diffusers"
smash_config["compiler"] = "torch_compile"
smash_config["torch_compile_target"] = "module_list"
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 2
# for the best results in terms of speed you can add these configs
# however they will increase your warmup time from 1.5 min to 10 min
# smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
# smash_config["quantizer"] = "torchao"
# smash_config["torchao_quant_type"] = "fp8dq"
# smash_config["torchao_excluded_modules"] = "norm+embedding"
# optimize the model
smashed_pipe = smash(pipe, smash_config)
# run the model
smashed_pipe("a knitted purple prune").images[0]
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_smashed_comparison.png">
</div>
After optimization, we can share and load the optimized model using the Hugging Face Hub.
```python
# save the model
smashed_pipe.save_to_hub("<username>/FLUX.1-dev-smashed")
# load the model
smashed_pipe = PrunaModel.from_hub("<username>/FLUX.1-dev-smashed")
```
## Evaluate and benchmark Diffusers models
Pruna provides the [EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) to evaluate the quality of your optimized models.
We can metrics we care about, such as total time and throughput, and the dataset to evaluate on. We can define a model and pass it to the `EvaluationAgent`.
<hfoptions id="eval">
<hfoption id="optimized model">
We can load and evaluate an optimized model by using the `EvaluationAgent` and pass it to the `Task`.
```python
import torch
from diffusers import FluxPipeline
from pruna import PrunaModel
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import (
ThroughputMetric,
TorchMetricWrapper,
TotalTimeMetric,
)
from pruna.evaluation.task import Task
# define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# load the model
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed")
# Define the metrics
metrics = [
TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),
ThroughputMetric(n_iterations=20, n_warmup_iterations=5),
TorchMetricWrapper("clip"),
]
# Define the datamodule
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(10)
# Define the task and evaluation agent
task = Task(metrics, datamodule=datamodule, device=device)
eval_agent = EvaluationAgent(task)
# Evaluate smashed model and offload it to CPU
smashed_pipe.move_to_device(device)
smashed_pipe_results = eval_agent.evaluate(smashed_pipe)
smashed_pipe.move_to_device("cpu")
```
</hfoption>
<hfoption id="standalone model">
Instead of comparing the optimized model to the base model, you can also evaluate the standalone `diffusers` model. This is useful if you want to evaluate the performance of the model without the optimization. We can do so by using the `PrunaModel` wrapper and run the `EvaluationAgent` on it.
```python
import torch
from diffusers import FluxPipeline
from pruna import PrunaModel
# load the model
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
).to("cpu")
wrapped_pipe = PrunaModel(model=pipe)
```
</hfoption>
</hfoptions>
Now that you have seen how to optimize and evaluate your models, you can start using Pruna to optimize your own models. Luckily, we have many examples to help you get started.
> [!TIP]
> For more details about benchmarking Flux, check out the [Announcing FLUX-Juiced: The Fastest Image Generation Endpoint (2.6 times faster)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) blog post and the [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) Space.
## Reference
- [Pruna](https://github.com/pruna-ai/pruna)
- [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)
- [Pruna evaluation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)
- [Pruna tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)
@@ -76,24 +76,6 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- lora_alpha vs. rank:
This ratio dictates the LoRA's effective strength:
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
@@ -21,8 +20,6 @@ import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
@@ -284,45 +281,3 @@ class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
@@ -55,7 +55,6 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_set_state_dict_into_text_encoder,
cast_training_params,
compute_density_for_timestep_sampling,
@@ -432,13 +431,6 @@ def parse_args(input_args=None):
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
@@ -1564,7 +1556,7 @@ def main(args):
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.lora_alpha,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
@@ -1573,7 +1565,7 @@ def main(args):
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.lora_alpha,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1590,15 +1582,13 @@ def main(args):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["text_encoder"] = model
elif isinstance(model, type(unwrap_model(text_encoder_two))):
pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
else:
@@ -1611,7 +1601,6 @@ def main(args):
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
@@ -2370,19 +2359,16 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
if args.train_text_encoder:
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
modules_to_save["text_encoder"] = text_encoder_one
else:
text_encoder_lora_layers = None
@@ -2391,7 +2377,6 @@ def main(args):
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
**_collate_lora_metadata(modules_to_save),
)
if args.train_text_encoder_ti:
-17
View File
@@ -170,23 +170,6 @@ accelerate launch train_dreambooth_lora_flux.py \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- lora_alpha vs. rank:
This ratio dictates the LoRA's effective strength:
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
@@ -21,8 +20,6 @@ import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
@@ -237,45 +234,3 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
@@ -27,6 +27,7 @@ from pathlib import Path
import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -52,7 +53,6 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_set_state_dict_into_text_encoder,
cast_training_params,
compute_density_for_timestep_sampling,
@@ -358,12 +358,7 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
@@ -1243,7 +1238,7 @@ def main(args):
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.lora_alpha,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
@@ -1252,7 +1247,7 @@ def main(args):
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.lora_alpha,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1269,14 +1264,12 @@ def main(args):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["text_encoder"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1287,7 +1280,6 @@ def main(args):
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)
def load_model_hook(models, input_dir):
@@ -1897,19 +1889,16 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
if args.train_text_encoder:
text_encoder_one = unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
modules_to_save["text_encoder"] = text_encoder_one
else:
text_encoder_lora_layers = None
@@ -1917,7 +1906,6 @@ def main(args):
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
**_collate_lora_metadata(modules_to_save),
)
# Final inference
@@ -29,7 +29,7 @@ from pathlib import Path
import numpy as np
import torch
import transformers
from accelerate import Accelerator, DistributedType
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
@@ -1181,15 +1181,13 @@ def main(args):
transformer_lora_layers_to_save = None
for model in models:
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
weights.pop()
HiDreamImagePipeline.save_lora_weights(
output_dir,
@@ -1199,20 +1197,13 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
while len(models) > 0:
model = models.pop()
while len(models) > 0:
model = models.pop()
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer"
)
transformer_.add_adapter(transformer_lora_config)
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)
@@ -1664,7 +1655,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
+31 -186
View File
@@ -7,17 +7,7 @@ from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import (
AutoencoderKLCosmos,
AutoencoderKLWan,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
CosmosTransformer3DModel,
CosmosVideoToWorldPipeline,
EDMEulerScheduler,
FlowMatchEulerDiscreteScheduler,
)
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -39,7 +29,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
TRANSFORMER_KEYS_RENAME_DICT = {
"t_embedder.1": "time_embed.t_embedder",
"affline_norm": "time_embed.norm",
".blocks.0.block.attn": ".attn1",
@@ -66,7 +56,7 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"blocks.block": rename_transformer_blocks_,
"logvar.0.freqs": remove_keys_,
"logvar.0.phases": remove_keys_,
@@ -74,45 +64,6 @@ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
"pos_embedder.seq": remove_keys_,
}
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"t_embedder.1": "time_embed.t_embedder",
"t_embedding_norm": "time_embed.norm",
"blocks": "transformer_blocks",
"adaln_modulation_self_attn.1": "norm1.linear_1",
"adaln_modulation_self_attn.2": "norm1.linear_2",
"adaln_modulation_cross_attn.1": "norm2.linear_1",
"adaln_modulation_cross_attn.2": "norm2.linear_2",
"adaln_modulation_mlp.1": "norm3.linear_1",
"adaln_modulation_mlp.2": "norm3.linear_2",
"self_attn": "attn1",
"cross_attn": "attn2",
"q_proj": "to_q",
"k_proj": "to_k",
"v_proj": "to_v",
"output_proj": "to_out.0",
"q_norm": "norm_q",
"k_norm": "norm_k",
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
# "extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
"accum_video_sample_counter": remove_keys_,
"accum_image_sample_counter": remove_keys_,
"accum_iteration": remove_keys_,
"accum_train_in_hours": remove_keys_,
"pos_embedder.seq": remove_keys_,
"pos_embedder.dim_spatial_range": remove_keys_,
"pos_embedder.dim_temporal_range": remove_keys_,
"_extra_state": remove_keys_,
}
TRANSFORMER_CONFIGS = {
"Cosmos-1.0-Diffusion-7B-Text2World": {
"in_channels": 16,
@@ -174,66 +125,6 @@ TRANSFORMER_CONFIGS = {
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
"Cosmos-2.0-Diffusion-2B-Text2Image": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 4.0, 4.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.0-Diffusion-14B-Text2Image": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 4.0, 4.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.0-Diffusion-2B-Video2World": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.0-Diffusion-14B-Video2World": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (20 / 24, 2.0, 2.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
}
VAE_KEYS_RENAME_DICT = {
@@ -325,18 +216,9 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
return state_dict
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
def convert_transformer(transformer_type: str, ckpt_path: str):
PREFIX_KEY = "net."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
if "Cosmos-1.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
elif "Cosmos-2.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
else:
assert False
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
with init_empty_weights():
config = TRANSFORMER_CONFIGS[transformer_type]
@@ -399,61 +281,13 @@ def convert_vae(vae_type: str):
return vae
def save_pipeline_cosmos_1_0(args, transformer, vae):
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
# So, the sigma_min values that is used is the default value of 0.002.
scheduler = EDMEulerScheduler(
sigma_min=0.002,
sigma_max=80,
sigma_data=0.5,
sigma_schedule="karras",
num_train_timesteps=1000,
prediction_type="epsilon",
rho=7.0,
final_sigmas_type="sigma_min",
)
pipe_cls = CosmosTextToWorldPipeline if "Text2World" in args.transformer_type else CosmosVideoToWorldPipeline
pipe = pipe_cls(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos_2_0(args, transformer, vae):
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
pipe_cls = Cosmos2TextToImagePipeline if "Text2Image" in args.transformer_type else Cosmos2VideoToWorldPipeline
pipe = pipe_cls(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument(
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
)
parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--save_pipeline", action="store_true")
@@ -482,26 +316,37 @@ if __name__ == "__main__":
assert args.tokenizer_path is not None
if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
vae = convert_vae(args.vae_type)
else:
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
vae = convert_vae(args.vae_type)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
save_pipeline_cosmos_1_0(args, transformer, vae)
elif "Cosmos-2.0" in args.transformer_type:
save_pipeline_cosmos_2_0(args, transformer, vae)
else:
assert False
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
# So, the sigma_min values that is used is the default value of 0.002.
scheduler = EDMEulerScheduler(
sigma_min=0.002,
sigma_max=80,
sigma_data=0.5,
sigma_schedule="karras",
num_train_timesteps=1000,
prediction_type="epsilon",
rho=7.0,
final_sigmas_type="sigma_min",
)
pipe = CosmosTextToWorldPipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+1 -1
View File
@@ -1,4 +1,4 @@
# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
import json
import os
-4
View File
@@ -363,8 +363,6 @@ else:
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
"CycleDiffusionPipeline",
@@ -955,8 +953,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
CosmosVideoToWorldPipeline,
CycleDiffusionPipeline,
+36 -51
View File
@@ -1596,10 +1596,7 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
converted_state_dict = {}
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
min_block = min(block_numbers)
max_block = max(block_numbers)
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
@@ -1625,57 +1622,45 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
# For the `diff_b` keys, we treat them as lora_bias.
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
for i in range(min_block, max_block + 1):
for i in range(num_blocks):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.self_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
)
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
)
if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.diff_b"
)
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
)
if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.diff_b"
)
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
)
if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.diff_b"
)
# FFN
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
@@ -1689,10 +1674,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key = f"blocks.{i}.{o}.diff_b"
converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
if f"blocks.{i}.{o}.diff_b" in original_state_dict:
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
f"blocks.{i}.{o}.diff_b"
)
# Remaining.
if original_state_dict:
+9 -37
View File
@@ -2031,36 +2031,18 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if is_kohya:
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
return (state_dict, None) if return_alphas else state_dict
is_xlabs = any("processor" in k for k in state_dict)
if is_xlabs:
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
# xlabs doesn't use `alpha`.
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
return (state_dict, None) if return_alphas else state_dict
is_bfl_control = any("query_norm.scale" in k for k in state_dict)
if is_bfl_control:
state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
return (state_dict, None) if return_alphas else state_dict
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
@@ -2079,13 +2061,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
)
if return_alphas or return_lora_metadata:
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=network_alphas,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
outputs = [state_dict]
if return_alphas:
outputs.append(network_alphas)
if return_lora_metadata:
outputs.append(metadata)
return tuple(outputs)
else:
return state_dict
@@ -2804,15 +2785,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
@staticmethod
def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
outputs = [state_dict]
if return_alphas:
outputs.append(alphas)
if return_metadata:
outputs.append(metadata)
return tuple(outputs) if (return_alphas or return_metadata) else state_dict
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
+1 -3
View File
@@ -187,9 +187,7 @@ class PeftAdapterMixin:
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
metadata:
LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
initialize `LoraConfig`.
metadata: TODO
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -749,16 +749,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
if self.decoder is not None
else 0,
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
if self.encoder is not None
else 0,
}
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -811,12 +801,18 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False
def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
def _count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, WanCausalConv3d):
count += 1
return count
self._conv_num = _count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = self._cached_conv_counts["encoder"]
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
@@ -100,15 +100,11 @@ class CosmosAdaLayerNorm(nn.Module):
embedded_timestep = self.linear_2(embedded_timestep)
if temb is not None:
embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim]
embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
shift, scale = embedded_timestep.chunk(2, dim=-1)
shift, scale = embedded_timestep.chunk(2, dim=1)
hidden_states = self.norm(hidden_states)
if embedded_timestep.ndim == 2:
shift, scale = (x.unsqueeze(1) for x in (shift, scale))
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return hidden_states
@@ -139,13 +135,9 @@ class CosmosAdaLayerNormZero(nn.Module):
if temb is not None:
embedded_timestep = embedded_timestep + temb
shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
shift, scale, gate = embedded_timestep.chunk(3, dim=1)
hidden_states = self.norm(hidden_states)
if embedded_timestep.ndim == 2:
shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return hidden_states, gate
@@ -263,19 +255,19 @@ class CosmosTransformerBlock(nn.Module):
# 1. Self Attention
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
hidden_states = hidden_states + gate * attn_output
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
# 2. Cross Attention
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
attn_output = self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
hidden_states = hidden_states + gate * attn_output
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
# 3. Feed Forward
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate * ff_output
hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
return hidden_states
@@ -521,23 +513,7 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
# 4. Timestep embeddings
if timestep.ndim == 1:
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
elif timestep.ndim == 5:
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
)
timestep = timestep.flatten()
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
# We can do this because num_frames == post_patch_num_frames, as p_t is 1
temb, embedded_timestep = (
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
.expand(-1, -1, post_patch_height, post_patch_width, -1)
.flatten(1, 3)
for x in (temb, embedded_timestep)
) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
else:
assert False
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
# 5. Transformer blocks
for block in self.transformer_blocks:
@@ -568,8 +544,8 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
# NOTE: The permutation order here is not the inverse operation of what happens when patching as usually expected.
# It might be a source of confusion to the reader, but this is correct
# Please just kill me at this point. What even is this permutation order and why is it different from the patching order?
# Another few hours of sanity lost to the void.
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+2 -12
View File
@@ -158,12 +158,7 @@ else:
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["cosmos"] = [
"Cosmos2TextToImagePipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
"Cosmos2VideoToWorldPipeline",
]
_import_structure["cosmos"] = ["CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -566,12 +561,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
)
from .cosmos import (
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
CosmosVideoToWorldPipeline,
)
from .cosmos import CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
@@ -22,8 +22,6 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
_import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
@@ -35,8 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
@@ -1,673 +0,0 @@
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosImagePipelineOutput
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:
class CosmosSafetyChecker:
def __init__(self, *args, **kwargs):
raise ImportError(
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import Cosmos2TextToImagePipeline
>>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Text2Image, nvidia/Cosmos-Predict2-14B-Text2Image
>>> model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
>>> pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
>>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
>>> output = pipe(
... prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
... ).images[0]
>>> output.save("output.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class Cosmos2TextToImagePipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Cosmos uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CosmosTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker"]
def __init__(
self,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
if safety_checker is None:
safety_checker = CosmosSafetyChecker()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.sigma_max = 80.0
self.sigma_min = 0.002
self.sigma_data = 1.0
self.final_sigmas_type = "sigma_min"
if self.scheduler is not None:
self.scheduler.register_to_config(
sigma_max=self.sigma_max,
sigma_min=self.sigma_min,
sigma_data=self.sigma_data,
final_sigmas_type=self.final_sigmas_type,
)
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
return_length=True,
return_offsets_mapping=False,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=prompt_attention_mask
).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
lengths = prompt_attention_mask.sum(dim=1).cpu()
for i, length in enumerate(lengths):
prompt_embeds[i, length:] = 0
return prompt_embeds
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt with num_videos_per_prompt->num_images_per_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self,
batch_size: int,
num_channels_latents: 16,
height: int = 768,
width: int = 1360,
num_frames: int = 1,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents * self.scheduler.config.sigma_max
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 768,
width: int = 1360,
num_inference_steps: int = 35,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `768`):
The height in pixels of the generated image.
width (`int`, defaults to `1360`):
The width in pixels of the generated image.
num_inference_steps (`int`, defaults to `35`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `7.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~CosmosImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`CosmosImagePipelineOutput`] is returned, otherwise a `tuple` is returned
where the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if self.safety_checker is None:
raise ValueError(
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
f"Please ensure that you are compliant with the license agreement."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
num_frames = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
device = self._execution_device
if self.safety_checker is not None:
self.safety_checker.to(device)
if prompt is not None:
prompt_list = [prompt] if isinstance(prompt, str) else prompt
for p in prompt_list:
if not self.safety_checker.check_text_safety(p):
raise ValueError(
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
f"prompt abides by the NVIDIA Open Model License Agreement."
)
self.safety_checker.to("cpu")
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare timesteps
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
# Replace the last sigma (which is zero) with the minimum sigma value
self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
# 5. Prepare latent variables
transformer_dtype = self.transformer.dtype
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
current_sigma = self.scheduler.sigmas[i]
current_t = current_sigma / (current_sigma + 1)
c_in = 1 - current_t
c_skip = 1 - current_t
c_out = -current_t
timestep = current_t.expand(latents.shape[0]).to(transformer_dtype) # [B, 1, T, 1, 1]
latent_model_input = latents * c_in
latent_model_input = latent_model_input.to(transformer_dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)
noise_pred = (latents - noise_pred) / current_sigma
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std / self.scheduler.config.sigma_data + latents_mean
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
if self.safety_checker is not None:
self.safety_checker.to(device)
video = self.video_processor.postprocess_video(video, output_type="np")
video = (video * 255).astype(np.uint8)
video_batch = []
for vid in video:
vid = self.safety_checker.check_video_safety(vid)
video_batch.append(vid)
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
self.safety_checker.to("cpu")
else:
video = self.video_processor.postprocess_video(video, output_type=output_type)
image = [batch[0] for batch in video]
if isinstance(video, torch.Tensor):
image = torch.stack(image)
elif isinstance(video, np.ndarray):
image = np.stack(image)
else:
image = latents[:, :, 0]
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return CosmosImagePipelineOutput(images=image)
@@ -1,792 +0,0 @@
# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:
class CosmosSafetyChecker:
def __init__(self, *args, **kwargs):
raise ImportError(
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import Cosmos2VideoToWorldPipeline
>>> from diffusers.utils import export_to_video, load_image
>>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Video2World, nvidia/Cosmos-Predict2-14B-Video2World
>>> model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
>>> pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
>>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yellow-scrubber.png"
... )
>>> video = pipe(
... image=image, prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=16)
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
r"""
Pipeline for video-to-world generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Cosmos uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CosmosTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker"]
def __init__(
self,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
if safety_checker is None:
safety_checker = CosmosSafetyChecker()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.sigma_max = 80.0
self.sigma_min = 0.002
self.sigma_data = 1.0
self.final_sigmas_type = "sigma_min"
if self.scheduler is not None:
self.scheduler.register_to_config(
sigma_max=self.sigma_max,
sigma_min=self.sigma_min,
sigma_data=self.sigma_data,
final_sigmas_type=self.final_sigmas_type,
)
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
return_length=True,
return_offsets_mapping=False,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=prompt_attention_mask
).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
lengths = prompt_attention_mask.sum(dim=1).cpu()
for i, length in enumerate(lengths):
prompt_embeds[i, length:] = 0
return prompt_embeds
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self,
video: torch.Tensor,
batch_size: int,
num_channels_latents: 16,
height: int = 704,
width: int = 1280,
num_frames: int = 93,
do_classifier_free_guidance: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
num_cond_frames = video.size(2)
if num_cond_frames >= num_frames:
# Take the last `num_frames` frames for conditioning
num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
video = video[:, :, -num_frames:]
else:
num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
num_padding_frames = num_frames - num_cond_frames
last_frame = video[:, :, -1:]
padding = last_frame.repeat(1, 1, num_padding_frames, 1, 1)
video = torch.cat([video, padding], dim=2)
if isinstance(generator, list):
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
)
init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.config.sigma_data
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
latents = latents * self.scheduler.config.sigma_max
padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_indicator[:, :, :num_cond_latent_frames] = 1.0
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
uncond_indicator = uncond_mask = None
if do_classifier_free_guidance:
uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput = None,
video: List[PipelineImageInput] = None,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 704,
width: int = 1280,
num_frames: int = 93,
num_inference_steps: int = 35,
guidance_scale: float = 7.0,
fps: int = 16,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
sigma_conditioning: float = 0.0001,
):
r"""
The call function to the pipeline for generation.
Args:
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
The image to be used as a conditioning input for the video generation.
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
The video to be used as a conditioning input for the video generation.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `704`):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `93`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `35`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `7.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`.
fps (`int`, defaults to `16`):
The frames per second of the generated video.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `512`):
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
the prompt is shorter than this length, it will be padded.
sigma_conditioning (`float`, defaults to `0.0001`):
The sigma value used for scaling conditioning latents. Ideally, it should not be changed or should be
set to a small value close to zero.
Examples:
Returns:
[`~CosmosPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if self.safety_checker is None:
raise ValueError(
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
f"Please ensure that you are compliant with the license agreement."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
device = self._execution_device
if self.safety_checker is not None:
self.safety_checker.to(device)
if prompt is not None:
prompt_list = [prompt] if isinstance(prompt, str) else prompt
for p in prompt_list:
if not self.safety_checker.check_text_safety(p):
raise ValueError(
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
f"prompt abides by the NVIDIA Open Model License Agreement."
)
self.safety_checker.to("cpu")
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare timesteps
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
if self.scheduler.config.final_sigmas_type == "sigma_min":
# Replace the last sigma (which is zero) with the minimum sigma value
self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
# 5. Prepare latent variables
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
if image is not None:
video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
else:
video = self.video_processor.preprocess_video(video, height, width)
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
video,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
self.do_classifier_free_guidance,
torch.float32,
device,
generator,
latents,
)
unconditioning_latents = None
cond_mask = cond_mask.to(transformer_dtype)
if self.do_classifier_free_guidance:
uncond_mask = uncond_mask.to(transformer_dtype)
unconditioning_latents = conditioning_latents
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
sigma_conditioning = torch.tensor(sigma_conditioning, dtype=torch.float32, device=device)
t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
current_sigma = self.scheduler.sigmas[i]
current_t = current_sigma / (current_sigma + 1)
c_in = 1 - current_t
c_skip = 1 - current_t
c_out = -current_t
timestep = current_t.view(1, 1, 1, 1, 1).expand(
latents.size(0), -1, latents.size(2), -1, -1
) # [B, 1, T, 1, 1]
cond_latent = latents * c_in
cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
cond_latent = cond_latent.to(transformer_dtype)
cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
cond_timestep = cond_timestep.to(transformer_dtype)
noise_pred = self.transformer(
hidden_states=cond_latent,
timestep=cond_timestep,
encoder_hidden_states=prompt_embeds,
fps=fps,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
noise_pred = cond_indicator * conditioning_latents + (1 - cond_indicator) * noise_pred
if self.do_classifier_free_guidance:
uncond_latent = latents * c_in
uncond_latent = uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * uncond_latent
uncond_latent = uncond_latent.to(transformer_dtype)
uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep
uncond_timestep = uncond_timestep.to(transformer_dtype)
noise_pred_uncond = self.transformer(
hidden_states=uncond_latent,
timestep=uncond_timestep,
encoder_hidden_states=negative_prompt_embeds,
fps=fps,
condition_mask=uncond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
noise_pred_uncond = (
uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * noise_pred_uncond
)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)
noise_pred = (latents - noise_pred) / current_sigma
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
if self.safety_checker is not None:
self.safety_checker.to(device)
video = self.video_processor.postprocess_video(video, output_type="np")
video = (video * 255).astype(np.uint8)
video_batch = []
for vid in video:
vid = self.safety_checker.check_video_safety(vid)
video_batch.append(vid)
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
self.safety_checker.to("cpu")
else:
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CosmosPipelineOutput(frames=video)
@@ -131,7 +131,7 @@ def retrieve_timesteps(
class CosmosTextToWorldPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-world generation using [Cosmos Predict1](https://github.com/nvidia-cosmos/cosmos-predict1).
Pipeline for text-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -426,12 +426,12 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `121`):
num_frames (`int`, defaults to `129`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `36`):
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `7.0`):
guidance_scale (`float`, defaults to `6.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
@@ -457,6 +457,9 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
@@ -174,8 +174,7 @@ def retrieve_latents(
class CosmosVideoToWorldPipeline(DiffusionPipeline):
r"""
Pipeline for image-to-world and video-to-world generation using [Cosmos
Predict-1](https://github.com/nvidia-cosmos/cosmos-predict1).
Pipeline for image-to-video and video-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -542,12 +541,12 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `121`):
num_frames (`int`, defaults to `129`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `36`):
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `7.0`):
guidance_scale (`float`, defaults to `6.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
@@ -573,6 +572,9 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
@@ -1,20 +1,14 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
import torch
from diffusers.utils import BaseOutput, get_logger
logger = get_logger(__name__)
from diffusers.utils import BaseOutput
@dataclass
class CosmosPipelineOutput(BaseOutput):
r"""
Output class for Cosmos any-to-world/video pipelines.
Output class for Cosmos pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
@@ -24,17 +18,3 @@ class CosmosPipelineOutput(BaseOutput):
"""
frames: torch.Tensor
@dataclass
class CosmosImagePipelineOutput(BaseOutput):
"""
Output class for Cosmos any-to-image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
-8
View File
@@ -247,14 +247,6 @@ def _set_state_dict_into_text_encoder(
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
metadatas = {}
for module_name, module in modules_to_save.items():
if module is not None:
metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
return metadatas
def compute_density_for_timestep_sampling(
weighting_scheme: str,
batch_size: int,
File diff suppressed because it is too large Load Diff
-300
View File
@@ -1,300 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Doc utilities: Utilities related to documentation
Adapted from:
https://github.com/huggingface/transformers/blob/5a95ed5ca0826c867e35e52f698db4d8fc907bcb/src/transformers/utils/doc.py
"""
import functools
import inspect
import re
import textwrap
import types
from collections import OrderedDict
from ..pipelines.auto_pipeline import AUTO_TEXT2IMAGE_PIPELINES_MAPPING
def get_docstring_indentation_level(func):
"""Return the indentation level of the start of the docstring of a class or function (or method)."""
# We assume classes are always defined in the global scope
if inspect.isclass(func):
return 4
source = inspect.getsource(func)
first_line = source.splitlines()[0]
function_def_level = len(first_line) - len(first_line.lstrip())
return 4 + function_def_level
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
return fn
return docstring_decorator
def add_start_docstrings_to_model_forward(*docstr):
def docstring_decorator(fn):
class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
intro = rf""" The {class_name} forward method, overrides the `__call__` special method.
<Tip>
Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.
</Tip>
"""
correct_indentation = get_docstring_indentation_level(fn)
current_doc = fn.__doc__ if fn.__doc__ is not None else ""
try:
first_non_empty = next(line for line in current_doc.splitlines() if line.strip() != "")
doc_indentation = len(first_non_empty) - len(first_non_empty.lstrip())
except StopIteration:
doc_indentation = correct_indentation
docs = docstr
# In this case, the correct indentation level (class method, 2 Python levels) was respected, and we should
# correctly reindent everything. Otherwise, the doc uses a single indentation level
if doc_indentation == 4 + correct_indentation:
docs = [textwrap.indent(textwrap.dedent(doc), " " * correct_indentation) for doc in docstr]
intro = textwrap.indent(textwrap.dedent(intro), " " * correct_indentation)
docstring = "".join(docs) + current_doc
fn.__doc__ = intro + docstring
return fn
return docstring_decorator
def add_end_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
return fn
return docstring_decorator
PT_RETURN_INTRODUCTION = r"""
Returns:
[`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed) comprising various
elements depending on the model and inputs.
"""
TEXT_TO_IMAGE_PIPELINE_CLASSES = list({p[0] for p in AUTO_TEXT2IMAGE_PIPELINES_MAPPING})
def _get_indent(t):
"""Returns the indentation in the first line of t"""
search = re.search(r"^(\s*)\S", t)
return "" if search is None else search.groups()[0]
def _convert_output_args_doc(output_args_doc):
"""Convert output_args_doc to display properly."""
# Split output_arg_doc in blocks argument/description
indent = _get_indent(output_args_doc)
blocks = []
current_block = ""
for line in output_args_doc.split("\n"):
# If the indent is the same as the beginning, the line is the name of new arg.
if _get_indent(line) == indent:
if len(current_block) > 0:
blocks.append(current_block[:-1])
current_block = f"{line}\n"
else:
# Otherwise it's part of the description of the current arg.
# We need to remove 2 spaces to the indentation.
current_block += f"{line[2:]}\n"
blocks.append(current_block[:-1])
# Format each block for proper rendering
for i in range(len(blocks)):
blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
return "\n".join(blocks)
def _prepare_output_docstrings(output_type, config_class, min_indent=None, add_intro=True):
"""
Prepares the return part of the docstring using `output_type`.
"""
output_docstring = output_type.__doc__
params_docstring = None
if output_docstring is not None:
# Remove the head of the docstring to keep the list of args only
lines = output_docstring.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
i += 1
if i < len(lines):
params_docstring = "\n".join(lines[(i + 1) :])
params_docstring = _convert_output_args_doc(params_docstring)
elif add_intro:
raise ValueError(
f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. Make sure it has "
"docstring and contain either `Args` or `Parameters`."
)
# Add the return introduction
if add_intro:
full_output_type = f"{output_type.__module__}.{output_type.__name__}"
intro = PT_RETURN_INTRODUCTION
intro = intro.format(full_output_type=full_output_type, config_class=config_class)
else:
full_output_type = str(output_type)
intro = f"\nReturns:\n `{full_output_type}`"
if params_docstring is not None:
intro += ":\n"
result = intro
if params_docstring is not None:
result += params_docstring
# Apply minimum indent if necessary
if min_indent is not None:
lines = result.split("\n")
# Find the indent of the first nonempty line
i = 0
while len(lines[i]) == 0:
i += 1
indent = len(_get_indent(lines[i]))
# If too small, add indentation to all nonempty lines
if indent < min_indent:
to_add = " " * (min_indent - indent)
lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
result = "\n".join(lines)
return result
FAKE_MODEL_DISCLAIMER = """
<Tip warning={true}>
This example uses a random model as the real ones are all very big. To get proper results, you should use
{real_checkpoint} instead of {fake_checkpoint}. If you get out-of-memory when loading that checkpoint, you can
refer to our optimization docs.
</Tip>
"""
PT_TEXT_TO_IMAGE_SAMPLE = r"""
Example:
```python
>>> from diffusers import DiffusionPipeline
>>> import torch
>>> # If memory doesn't allow, enable optimizations like `enable_model_cpu_offload()`.
>>> pipe = DiffusionPipeline.from_pretrained("{checkpoint}", torch_dtype=torch.bfloat16).to("cuda")
>>> prompt = "a photo of a cute dog."
>>> image = pipe(prompt).images[0] # Configure other pipe call arguments as needed.
```
"""
PT_SAMPLE_DOCSTRINGS = {
"Text2Image": PT_TEXT_TO_IMAGE_SAMPLE
}
PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS = OrderedDict(["text-to-image", PT_TEXT_TO_IMAGE_SAMPLE])
def filter_outputs_from_example(docstring, **kwargs):
"""
Removes the lines testing an output with the doctest syntax in a code sample when it's set to `None`.
"""
for key, value in kwargs.items():
if value is not None:
continue
doc_key = "{" + key + "}"
docstring = re.sub(rf"\n([^\n]+)\n\s+{doc_key}\n", "\n", docstring)
return docstring
def add_code_sample_docstrings(
*docstr,
checkpoint=None,
output_type=None,
config_class=None,
model_cls=None,
):
def docstring_decorator(fn):
# model_class defaults to function's class if not specified otherwise
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
sample_docstrings = PT_SAMPLE_DOCSTRINGS
# putting all kwargs for docstrings in a dict to be used
# with the `.format(**doc_kwargs)`. Note that string might
# be formatted with non-existing keys, which is fine.
doc_kwargs = {
"checkpoint": checkpoint,
"true": "{true}", # For <Tip warning={true}> syntax that conflicts with formatting.
}
if model_class in TEXT_TO_IMAGE_PIPELINE_CLASSES:
code_sample = sample_docstrings["Text2Image"]
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
code_sample = filter_outputs_from_example(code_sample)
func_doc = (fn.__doc__ or "") + "".join(docstr)
output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
built_doc = code_sample.format(**doc_kwargs)
fn.__doc__ = func_doc + output_doc + built_doc
return fn
return docstring_decorator
def replace_return_docstrings(output_type=None, config_class=None):
def docstring_decorator(fn):
func_doc = fn.__doc__
lines = func_doc.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
i += 1
if i < len(lines):
indent = len(_get_indent(lines[i]))
lines[i] = _prepare_output_docstrings(output_type, config_class, min_indent=indent)
func_doc = "\n".join(lines)
else:
raise ValueError(
f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, "
f"current docstring is:\n{func_doc}"
)
fn.__doc__ = func_doc
return fn
return docstring_decorator
def copy_func(f):
"""Returns a copy of a function f."""
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
return g
@@ -422,36 +422,6 @@ class ConsisIDPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2VideoToWorldPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CosmosTextToWorldPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+2 -5
View File
@@ -359,8 +359,5 @@ def _load_sft_state_dict_metadata(model_file: str):
metadata = f.metadata() or {}
metadata.pop("format", None)
if metadata:
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
return json.loads(raw) if raw else None
else:
return None
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
return json.loads(raw) if raw else None
-39
View File
@@ -1736,45 +1736,6 @@ class ModelTesterMixin:
f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
)
@parameterized.expand(
[
(-1, "You can't pass device_map as a negative int"),
("foo", "When passing device_map as a string, the value needs to be a device name"),
]
)
def test_wrong_device_map_raises_error(self, device_map, msg_substring):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
with self.assertRaises(ValueError) as err_ctx:
_ = self.model_class.from_pretrained(tmpdir, device_map=device_map)
assert msg_substring in str(err_ctx.exception)
@parameterized.expand([0, "cuda", torch.device("cuda")])
@require_torch_gpu
def test_passing_non_dict_device_map_works(self, device_map):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).eval()
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
_ = loaded_model(**inputs_dict)
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
@require_torch_gpu
def test_passing_dict_device_map_works(self, name, device):
# There are other valid dict-based `device_map` values too. It's best to refer to
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).eval()
device_map = {name: device}
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
_ = loaded_model(**inputs_dict)
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
@@ -46,6 +46,7 @@ from diffusers.utils.testing_utils import (
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_gpu,
skip_mps,
slow,
torch_all_close,
@@ -1083,6 +1084,42 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@parameterized.expand(
[
(-1, "You can't pass device_map as a negative int"),
("foo", "When passing device_map as a string, the value needs to be a device name"),
]
)
def test_wrong_device_map_raises_error(self, device_map, msg_substring):
with self.assertRaises(ValueError) as err_ctx:
_ = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
)
assert msg_substring in str(err_ctx.exception)
@parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
@require_torch_gpu
def test_passing_non_dict_device_map_works(self, device_map):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
)
output = loaded_model(**inputs_dict)
assert output.sample.shape == (4, 4, 16, 16)
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
@require_torch_gpu
def test_passing_dict_device_map_works(self, name, device_map):
# There are other valid dict-based `device_map` values too. It's best to refer to
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
)
output = loaded_model(**inputs_dict)
assert output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1,337 +0,0 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLWan,
Cosmos2TextToImagePipeline,
CosmosTransformer3DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
enable_full_determinism()
class Cosmos2TextToImagePipelineWrapper(Cosmos2TextToImagePipeline):
@staticmethod
def from_pretrained(*args, **kwargs):
kwargs["safety_checker"] = DummyCosmosSafetyChecker()
return Cosmos2TextToImagePipeline.from_pretrained(*args, **kwargs)
class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2TextToImagePipelineWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = CosmosTransformer3DModel(
in_channels=16,
out_channels=16,
num_attention_heads=2,
attention_head_dim=16,
num_layers=2,
mlp_ratio=2,
text_embed_dim=32,
adaln_lora_dim=4,
max_size=(4, 32, 32),
patch_size=(1, 2, 2),
rope_scale=(2.0, 1.0, 1.0),
concat_padding_mask=True,
extra_pos_embed_type="learnable",
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
# We cannot run the Cosmos Guardrail for fast tests due to the large model size
"safety_checker": DummyCosmosSafetyChecker(),
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_video = torch.randn(3, 32, 32)
max_diff = np.abs(generated_image - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in a everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
self.pipeline_class._optional_components.remove("safety_checker")
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
self.pipeline_class._optional_components.append("safety_checker")
def test_serialization_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
model_components = [
component_name
for component_name, component in pipe.components.items()
if isinstance(component, torch.nn.Module)
]
model_components.remove("safety_checker")
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
with open(f"{tmpdir}/model_index.json", "r") as f:
config = json.load(f)
for subfolder in os.listdir(tmpdir):
if not os.path.isfile(subfolder) and subfolder in model_components:
folder_path = os.path.join(tmpdir, subfolder)
is_folder = os.path.isdir(folder_path) and subfolder in config
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
self.skipTest("No dummy components defined.")
pipe = self.pipeline_class(**components)
specified_key = next(iter(components.keys()))
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
pipe.save_pretrained(tmpdirname, safe_serialization=False)
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
)
for name, component in loaded_pipe.components.items():
if name == "safety_checker":
continue
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
self.assertEqual(
component.dtype,
expected_dtype,
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
@unittest.skip(
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
"too large and slow to run on CI."
)
def test_encode_prompt_works_in_isolation(self):
pass
@@ -1,351 +0,0 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
import numpy as np
import PIL.Image
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLWan,
Cosmos2VideoToWorldPipeline,
CosmosTransformer3DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
enable_full_determinism()
class Cosmos2VideoToWorldPipelineWrapper(Cosmos2VideoToWorldPipeline):
@staticmethod
def from_pretrained(*args, **kwargs):
kwargs["safety_checker"] = DummyCosmosSafetyChecker()
return Cosmos2VideoToWorldPipeline.from_pretrained(*args, **kwargs)
class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2VideoToWorldPipelineWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image", "video"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = CosmosTransformer3DModel(
in_channels=16 + 1,
out_channels=16,
num_attention_heads=2,
attention_head_dim=16,
num_layers=2,
mlp_ratio=2,
text_embed_dim=32,
adaln_lora_dim=4,
max_size=(4, 32, 32),
patch_size=(1, 2, 2),
rope_scale=(2.0, 1.0, 1.0),
concat_padding_mask=True,
extra_pos_embed_type="learnable",
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
# We cannot run the Cosmos Guardrail for fast tests due to the large model size
"safety_checker": DummyCosmosSafetyChecker(),
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 32
image_width = 32
image = PIL.Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": image_height,
"width": image_width,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
expected_video = torch.randn(9, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_components_function(self):
init_components = self.get_dummy_components()
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in a everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
self.pipeline_class._optional_components.remove("safety_checker")
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
self.pipeline_class._optional_components.append("safety_checker")
def test_serialization_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
model_components = [
component_name
for component_name, component in pipe.components.items()
if isinstance(component, torch.nn.Module)
]
model_components.remove("safety_checker")
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
with open(f"{tmpdir}/model_index.json", "r") as f:
config = json.load(f)
for subfolder in os.listdir(tmpdir):
if not os.path.isfile(subfolder) and subfolder in model_components:
folder_path = os.path.join(tmpdir, subfolder)
is_folder = os.path.isdir(folder_path) and subfolder in config
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
self.skipTest("No dummy components defined.")
pipe = self.pipeline_class(**components)
specified_key = next(iter(components.keys()))
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
pipe.save_pretrained(tmpdirname, safe_serialization=False)
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
)
for name, component in loaded_pipe.components.items():
if name == "safety_checker":
continue
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
self.assertEqual(
component.dtype,
expected_dtype,
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
@unittest.skip(
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
"too large and slow to run on CI."
)
def test_encode_prompt_works_in_isolation(self):
pass