Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c73c00610e |
@@ -429,7 +429,7 @@
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTXVideo
|
||||
title: LTX
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanVideo
|
||||
|
||||
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
|
||||
vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanVideo
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTXVideo
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HunyuanVideoTransformer3DModel
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import LTXVideoTransformer3DModel
|
||||
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## LTXVideoTransformer3DModel
|
||||
|
||||
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import SanaTransformer2DModel
|
||||
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## SanaTransformer2DModel
|
||||
|
||||
@@ -29,7 +29,7 @@ Recommendations for inference:
|
||||
- Transformer should be in `torch.bfloat16`.
|
||||
- VAE should be in `torch.float16`.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
|
||||
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
|
||||
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
|
||||
|
||||
## HunyuanVideoPipeline
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# LTX Video
|
||||
# LTX
|
||||
|
||||
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
|
||||
|
||||
@@ -22,24 +22,14 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
Available models:
|
||||
|
||||
| Model name | Recommended dtype |
|
||||
|:-------------:|:-----------------:|
|
||||
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
|
||||
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
|
||||
|
||||
## Loading Single Files
|
||||
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
|
||||
|
||||
# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
|
||||
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
single_file_url, torch_dtype=torch.bfloat16
|
||||
@@ -109,34 +99,6 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)
|
||||
|
||||
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
|
||||
|
||||
<!-- TODO(aryan): Update this when official weights are supported -->
|
||||
|
||||
Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=768,
|
||||
height=512,
|
||||
num_frames=161,
|
||||
decode_timestep=0.03,
|
||||
decode_noise_scale=0.025,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
|
||||
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
@@ -32,9 +32,9 @@ Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
|
||||
|
||||
@@ -27,7 +27,7 @@ The example below only quantizes the weights to int8.
|
||||
```python
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
model_id = "black-forest-labs/Flux.1-Dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
@@ -45,9 +45,7 @@ pipe = FluxPipeline.from_pretrained(
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
).images[0]
|
||||
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRASANA(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
|
||||
transformer_layer_type = "transformer_blocks.0.attn1.to_k"
|
||||
|
||||
def test_dreambooth_lora_sana(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# `self.transformer_layer_type` should be in the state dict.
|
||||
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 166
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
resume_run_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -943,7 +943,7 @@ def main(args):
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
text_encoder = Gemma2Model.from_pretrained(
|
||||
@@ -964,6 +964,15 @@ def main(args):
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = SanaPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=None,
|
||||
transformer=None,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
@@ -984,15 +993,6 @@ def main(args):
|
||||
# because Gemma2 is particularly suited for bfloat16.
|
||||
text_encoder.to(dtype=torch.bfloat16)
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = SanaPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=None,
|
||||
transformer=None,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
transformer.enable_gradient_checkpointing()
|
||||
|
||||
@@ -1182,7 +1182,6 @@ def main(args):
|
||||
)
|
||||
if args.offload:
|
||||
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
|
||||
prompt_embeds = prompt_embeds.to(transformer.dtype)
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
@@ -1217,7 +1216,7 @@ def main(args):
|
||||
vae_config_scaling_factor = vae.config.scaling_factor
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
vae = vae.to(accelerator.device)
|
||||
vae = vae.to("cuda")
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
with torch.no_grad():
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
@@ -23,9 +21,7 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"vae": remove_keys_,
|
||||
}
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
# decoder
|
||||
@@ -58,31 +54,10 @@ VAE_KEYS_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
@@ -105,16 +80,13 @@ def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
PREFIX_KEY = ""
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -125,21 +97,16 @@ def convert_transformer(
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
transformer.load_state_dict(original_state_dict, strict=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
PREFIX_KEY = "vae."
|
||||
|
||||
def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTXVideo(**config)
|
||||
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -150,60 +117,10 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
vae.load_state_dict(original_state_dict, strict=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.0":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"timestep_conditioning": False,
|
||||
}
|
||||
elif version == "0.9.1":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -222,9 +139,6 @@ def get_args():
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -247,7 +161,6 @@ if __name__ == "__main__":
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
output_path = Path(args.output_path)
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
@@ -256,14 +169,13 @@ if __name__ == "__main__":
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
|
||||
if args.vae_ckpt_path is not None:
|
||||
config = get_vae_config(args.version)
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
|
||||
@@ -88,18 +88,13 @@ def main(args):
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 3.0
|
||||
|
||||
# model config
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
layer_num = 20
|
||||
elif args.model_type == "SanaMS_600M_P1_D28":
|
||||
layer_num = 28
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
@@ -181,7 +176,6 @@ def main(args):
|
||||
patch_size=1,
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
interpolation_scale=interpolation_scale[args.image_size],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
|
||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.32.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.32.0"
|
||||
__version__ = "0.32.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ from .single_file_utils import (
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
@@ -102,10 +101,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"HunyuanVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -225,7 +220,6 @@ class FromOriginalModelMixin:
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
@@ -303,7 +297,7 @@ class FromOriginalModelMixin:
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=config_revision,
|
||||
revision=revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
|
||||
@@ -108,7 +108,6 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
|
||||
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
|
||||
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -157,14 +156,12 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
|
||||
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
|
||||
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -606,10 +603,7 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
model_type = "ltx-video-0.9.1"
|
||||
else:
|
||||
model_type = "ltx-video"
|
||||
model_type = "ltx-video"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
|
||||
encoder_key = "encoder.project_in.conv.conv.bias"
|
||||
@@ -630,9 +624,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
|
||||
model_type = "mochi-1-preview"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
|
||||
model_type = "hunyuan-video"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -2342,32 +2333,12 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
@@ -2551,133 +2522,3 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict):
|
||||
hidden_size = 3072
|
||||
|
||||
if "linear1.weight" in key:
|
||||
linear1_weight = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
||||
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
|
||||
state_dict[f"{new_key}.attn.to_q.weight"] = q
|
||||
state_dict[f"{new_key}.attn.to_k.weight"] = k
|
||||
state_dict[f"{new_key}.attn.to_v.weight"] = v
|
||||
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
|
||||
|
||||
elif "linear1.bias" in key:
|
||||
linear1_bias = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
|
||||
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
|
||||
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
|
||||
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
|
||||
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
|
||||
|
||||
else:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
||||
new_key = new_key.replace("linear2", "proj_out")
|
||||
new_key = new_key.replace("q_norm", "attn.norm_q")
|
||||
new_key = new_key.replace("k_norm", "attn.norm_k")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
def update_state_dict_(state_dict, old_key, new_key):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(checkpoint, key, new_key)
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, checkpoint)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@@ -177,5 +177,3 @@ class FluxTransformer2DLoadersMixin:
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
@@ -22,14 +22,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class LTXVideoCausalConv3d(nn.Module):
|
||||
class LTXCausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -80,9 +79,9 @@ class LTXVideoCausalConv3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoResnetBlock3d(nn.Module):
|
||||
class LTXResnetBlock3d(nn.Module):
|
||||
r"""
|
||||
A 3D ResNet block used in the LTXVideo model.
|
||||
A 3D ResNet block used in the LTX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -110,9 +109,7 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
elementwise_affine: bool = False,
|
||||
non_linearity: str = "swish",
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
@@ -120,13 +117,13 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
|
||||
self.conv1 = LTXVideoCausalConv3d(
|
||||
self.conv1 = LTXCausalConv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = LTXVideoCausalConv3d(
|
||||
self.conv2 = LTXCausalConv3d(
|
||||
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
|
||||
)
|
||||
|
||||
@@ -134,58 +131,22 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
|
||||
self.conv_shortcut = LTXVideoCausalConv3d(
|
||||
self.conv_shortcut = LTXCausalConv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.per_channel_scale1 = None
|
||||
self.per_channel_scale2 = None
|
||||
if inject_noise:
|
||||
self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
|
||||
self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
|
||||
|
||||
self.scale_shift_table = None
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
|
||||
|
||||
def forward(
|
||||
self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
|
||||
) -> torch.Tensor:
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = inputs
|
||||
|
||||
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.scale_shift_table is not None:
|
||||
temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
|
||||
shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
|
||||
hidden_states = hidden_states * (1 + scale_1) + shift_1
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if self.per_channel_scale1 is not None:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
spatial_noise = torch.randn(
|
||||
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)[None]
|
||||
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
|
||||
|
||||
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.scale_shift_table is not None:
|
||||
hidden_states = hidden_states * (1 + scale_2) + shift_2
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.per_channel_scale2 is not None:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
spatial_noise = torch.randn(
|
||||
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)[None]
|
||||
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
|
||||
|
||||
if self.norm3 is not None:
|
||||
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
@@ -196,24 +157,20 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
class LTXUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
||||
self.residual = residual
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
out_channels = in_channels * stride[0] * stride[1] * stride[2]
|
||||
|
||||
self.conv = LTXVideoCausalConv3d(
|
||||
self.conv = LTXCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
@@ -224,15 +181,6 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.residual:
|
||||
residual = hidden_states.reshape(
|
||||
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
|
||||
)
|
||||
residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
|
||||
residual = residual.repeat(1, repeats, 1, 1, 1)
|
||||
residual = residual[:, :, self.stride[0] - 1 :]
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
|
||||
@@ -240,15 +188,12 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
|
||||
|
||||
if self.residual:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoDownBlock3D(nn.Module):
|
||||
class LTXDownBlock3D(nn.Module):
|
||||
r"""
|
||||
Down block used in the LTXVideo model.
|
||||
Down block used in the LTX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -290,7 +235,7 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXVideoResnetBlock3d(
|
||||
LTXResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
@@ -305,7 +250,7 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
if spatio_temporal_scale:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoCausalConv3d(
|
||||
LTXCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
@@ -317,7 +262,7 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
self.conv_out = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_out = LTXVideoResnetBlock3d(
|
||||
self.conv_out = LTXResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
@@ -328,12 +273,7 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXDownBlock3D` class."""
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
@@ -345,26 +285,24 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
if self.conv_out is not None:
|
||||
hidden_states = self.conv_out(hidden_states, temb, generator)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
|
||||
class LTXVideoMidBlock3d(nn.Module):
|
||||
class LTXMidBlock3d(nn.Module):
|
||||
r"""
|
||||
A middle block used in the LTXVideo model.
|
||||
A middle block used in the LTX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -391,51 +329,28 @@ class LTXVideoMidBlock3d(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.time_embedder = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXVideoResnetBlock3d(
|
||||
LTXResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXMidBlock3D` class."""
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -445,18 +360,16 @@ class LTXVideoMidBlock3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoUpBlock3d(nn.Module):
|
||||
class LTXUpBlock3d(nn.Module):
|
||||
r"""
|
||||
Up block used in the LTXVideo model.
|
||||
Up block used in the LTX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -490,82 +403,45 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.time_embedder = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
|
||||
|
||||
self.conv_in = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_in = LTXVideoResnetBlock3d(
|
||||
self.conv_in = LTXResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
|
||||
self.upsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoUpsampler3d(
|
||||
out_channels * upscale_factor,
|
||||
stride=(2, 2, 2),
|
||||
is_causal=is_causal,
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
)
|
||||
]
|
||||
)
|
||||
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXVideoResnetBlock3d(
|
||||
LTXResnetBlock3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.conv_in is not None:
|
||||
hidden_states = self.conv_in(hidden_states, temb, generator)
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -580,18 +456,16 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoEncoder3d(nn.Module):
|
||||
class LTXEncoder3d(nn.Module):
|
||||
r"""
|
||||
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
|
||||
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
|
||||
representation.
|
||||
|
||||
Args:
|
||||
@@ -635,7 +509,7 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
self.conv_in = LTXCausalConv3d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
@@ -650,7 +524,7 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
down_block = LTXDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
@@ -662,7 +536,7 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid block
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
self.mid_block = LTXMidBlock3d(
|
||||
in_channels=output_channel,
|
||||
num_layers=layers_per_block[-1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
@@ -672,14 +546,14 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
# out
|
||||
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LTXVideoCausalConv3d(
|
||||
self.conv_out = LTXCausalConv3d(
|
||||
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""The forward method of the `LTXVideoEncoder3d` class."""
|
||||
r"""The forward method of the `LTXEncoder3D` class."""
|
||||
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
@@ -725,10 +599,9 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoDecoder3d(nn.Module):
|
||||
class LTXDecoder3d(nn.Module):
|
||||
r"""
|
||||
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
|
||||
sample.
|
||||
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 128):
|
||||
@@ -749,8 +622,6 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
Epsilon value for ResNet normalization layers.
|
||||
is_causal (`bool`, defaults to `False`):
|
||||
Whether this layer behaves causally (future frames depend only on past frames) or not.
|
||||
timestep_conditioning (`bool`, defaults to `False`):
|
||||
Whether to condition the model on timesteps.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -764,10 +635,6 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: Tuple[bool, ...] = (False, False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -778,42 +645,30 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
block_out_channels = tuple(reversed(block_out_channels))
|
||||
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
|
||||
layers_per_block = tuple(reversed(layers_per_block))
|
||||
inject_noise = tuple(reversed(inject_noise))
|
||||
upsample_residual = tuple(reversed(upsample_residual))
|
||||
upsample_factor = tuple(reversed(upsample_factor))
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
self.conv_in = LTXCausalConv3d(
|
||||
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
in_channels=output_channel,
|
||||
num_layers=layers_per_block[0],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise[0],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
self.mid_block = LTXMidBlock3d(
|
||||
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
|
||||
)
|
||||
|
||||
# up blocks
|
||||
num_block_out_channels = len(block_out_channels)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i in range(num_block_out_channels):
|
||||
input_channel = output_channel // upsample_factor[i]
|
||||
output_channel = block_out_channels[i] // upsample_factor[i]
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
up_block = LTXVideoUpBlock3d(
|
||||
up_block = LTXUpBlock3d(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i + 1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise[i + 1],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
upsample_residual=upsample_residual[i],
|
||||
upscale_factor=upsample_factor[i],
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
@@ -821,20 +676,13 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
# out
|
||||
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LTXVideoCausalConv3d(
|
||||
self.conv_out = LTXCausalConv3d(
|
||||
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
# timestep embedding
|
||||
self.time_embedder = None
|
||||
self.scale_shift_table = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
@@ -845,33 +693,17 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states, temb)
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states, temb)
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
|
||||
temb = temb + self.scale_shift_table[None, ..., None, None, None]
|
||||
shift, scale = temb.unbind(dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
@@ -934,15 +766,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -952,7 +777,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = LTXVideoEncoder3d(
|
||||
self.encoder = LTXEncoder3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
@@ -963,20 +788,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
is_causal=encoder_causal,
|
||||
)
|
||||
self.decoder = LTXVideoDecoder3d(
|
||||
self.decoder = LTXDecoder3d(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
block_out_channels=block_out_channels,
|
||||
spatio_temporal_scaling=spatio_temporal_scaling,
|
||||
layers_per_block=layers_per_block,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
is_causal=decoder_causal,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
inject_noise=decoder_inject_noise,
|
||||
upsample_residual=upsample_residual,
|
||||
upsample_factor=upsample_factor,
|
||||
)
|
||||
|
||||
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
|
||||
@@ -1016,7 +837,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_sample_stride_width = 448
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
|
||||
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(
|
||||
@@ -1115,15 +936,13 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z, temb, return_dict=return_dict)
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
if self.use_framewise_decoding:
|
||||
# TODO(aryan): requires investigation
|
||||
@@ -1133,7 +952,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
dec = self.decoder(z, temb)
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
@@ -1141,9 +960,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1158,15 +975,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
if temb is not None:
|
||||
decoded_slices = [
|
||||
self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1))
|
||||
]
|
||||
else:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z, temb).sample
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
@@ -1248,9 +1060,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1291,9 +1101,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
time = self.decoder(
|
||||
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
|
||||
)
|
||||
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
|
||||
|
||||
row.append(time)
|
||||
rows.append(row)
|
||||
@@ -1321,7 +1129,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
@@ -1332,7 +1139,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, temb)
|
||||
dec = self.decode(z)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return dec
|
||||
|
||||
@@ -748,10 +748,10 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
pos_embedding = self._get_positional_embeddings(
|
||||
height, width, pre_time_compression_frames, device=embeds.device
|
||||
)
|
||||
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
|
||||
else:
|
||||
pos_embedding = self.pos_embedding
|
||||
|
||||
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
|
||||
embeds = embeds + pos_embedding
|
||||
|
||||
return embeds
|
||||
|
||||
@@ -228,7 +228,7 @@ def load_model_dict_into_meta(
|
||||
else:
|
||||
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
||||
raise ValueError(
|
||||
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
)
|
||||
|
||||
if is_quantized and (
|
||||
|
||||
@@ -242,7 +242,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
patch_size: int = 1,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -250,14 +249,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch Embedding
|
||||
interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
|
||||
self.patch_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
interpolation_scale=None,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
|
||||
@@ -18,8 +18,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
@@ -502,7 +500,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LTXVideoAttentionProcessor2_0:
|
||||
class LTXAttentionProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
|
||||
@@ -44,7 +44,7 @@ class LTXVideoAttentionProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
@@ -92,7 +92,7 @@ class LTXVideoAttentionProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
class LTXRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -164,7 +164,7 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class LTXVideoTransformerBlock(nn.Module):
|
||||
class LTXTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
|
||||
|
||||
@@ -208,7 +208,7 @@ class LTXVideoTransformerBlock(nn.Module):
|
||||
cross_attention_dim=None,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXVideoAttentionProcessor2_0(),
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -221,7 +221,7 @@ class LTXVideoTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXVideoAttentionProcessor2_0(),
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
@@ -327,7 +327,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.rope = LTXVideoRotaryPosEmbed(
|
||||
self.rope = LTXRotaryPosEmbed(
|
||||
dim=inner_dim,
|
||||
base_num_frames=20,
|
||||
base_height=2048,
|
||||
@@ -339,7 +339,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LTXVideoTransformerBlock(
|
||||
LTXTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
|
||||
@@ -39,7 +39,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
>>> model_id = "tencent/HunyuanVideo"
|
||||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
@@ -193,15 +193,15 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
|
||||
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
|
||||
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
|
||||
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
@@ -411,7 +411,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
@@ -419,8 +419,8 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
@@ -652,7 +652,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
@@ -660,8 +660,8 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
|
||||
@@ -511,8 +511,6 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -565,10 +563,6 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -759,25 +753,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
|
||||
@@ -571,8 +571,6 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -627,10 +625,6 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -855,25 +849,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
|
||||
@@ -59,13 +59,13 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import SanaPAGPipeline
|
||||
|
||||
>>> pipe = SanaPAGPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
|
||||
... pag_applied_layers=["transformer_blocks.8"],
|
||||
... torch_dtype=torch.float32,
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> pipe.text_encoder.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.float16)
|
||||
|
||||
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
|
||||
>>> image[0].save("output.png")
|
||||
|
||||
@@ -62,11 +62,11 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import SanaPipeline
|
||||
|
||||
>>> pipe = SanaPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> pipe.text_encoder.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.float16)
|
||||
|
||||
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
|
||||
>>> image[0].save("output.png")
|
||||
|
||||
@@ -93,11 +93,6 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
raise ImportError(
|
||||
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
|
||||
)
|
||||
torchao_version = version.parse(importlib.metadata.version("torch"))
|
||||
if torchao_version < version.parse("0.7.0"):
|
||||
raise RuntimeError(
|
||||
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
)
|
||||
|
||||
self.offload = False
|
||||
|
||||
|
||||
@@ -490,11 +490,11 @@ def require_gguf_version_greater_or_equal(gguf_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_torchao_version_greater_or_equal(torchao_version):
|
||||
def require_torchao_version_greater(torchao_version):
|
||||
def decorator(test_case):
|
||||
correct_torchao_version = is_torchao_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("torchao")).base_version
|
||||
) >= version.parse(torchao_version)
|
||||
) > version.parse(torchao_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
|
||||
)(test_case)
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -27,13 +29,16 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -118,6 +123,41 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@skip_mps
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -330,8 +331,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -340,32 +340,85 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
# Testing opposite direction where the LoRA params are zero-padded.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_B_bias(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
|
||||
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
|
||||
|
||||
# keep track of the bias values of the base layers to perform checks later.
|
||||
bias_values = {}
|
||||
for name, module in pipe.transformer.named_modules():
|
||||
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
if module.bias is not None:
|
||||
bias_values[name] = module.bias.data.clone()
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
denoiser_lora_config.lora_bias = False
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.delete_adapters("adapter-1")
|
||||
|
||||
denoiser_lora_config.lora_bias = True
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# for now this is flux control lora specific but can be generalized later and added to ./utils.py
|
||||
def test_correct_lora_configs_with_different_ranks(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
# change the rank_pattern
|
||||
updated_rank = denoiser_lora_config.r * 2
|
||||
denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
|
||||
"single_transformer_blocks.0.attn.to_k": updated_rank
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
# similarly change the alpha_pattern
|
||||
updated_alpha = denoiser_lora_config.lora_alpha * 2
|
||||
denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
|
||||
"single_transformer_blocks.0.attn.to_k": updated_alpha
|
||||
}
|
||||
|
||||
def test_normal_lora_with_expanded_lora_raises_error(self):
|
||||
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
|
||||
# load shape expanded LoRA (such as Control LoRA).
|
||||
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_lora_expanding_shape_with_normal_lora(self):
|
||||
# This test checks if it works when a lora with expanded shapes (like control loras) but
|
||||
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
|
||||
# tested with it.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
@@ -26,14 +28,16 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -140,6 +144,46 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
prompt=inputs["prompt"],
|
||||
height=inputs["height"],
|
||||
width=inputs["width"],
|
||||
num_frames=inputs["num_frames"],
|
||||
num_inference_steps=inputs["num_inference_steps"],
|
||||
max_sequence_length=inputs["max_sequence_length"],
|
||||
output_type="np",
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -24,12 +26,18 @@ from diffusers import (
|
||||
LTXPipeline,
|
||||
LTXVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -52,19 +60,10 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
transformer_cls = LTXVideoTransformer3DModel
|
||||
vae_kwargs = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"timestep_conditioning": False,
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
@@ -108,6 +107,41 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@skip_mps
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -15,20 +15,24 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -99,6 +103,40 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import Gemma2Model, GemmaTokenizer
|
||||
from transformers import Gemma2ForCausalLM, GemmaTokenizer
|
||||
|
||||
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
|
||||
@@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
vae_cls = AutoencoderDC
|
||||
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
|
||||
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
|
||||
text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
@@ -105,34 +105,34 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
@unittest.skip("Not supported in Sana.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
+2
-112
@@ -1528,7 +1528,7 @@ class PeftLoraLoaderMixinTests:
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=False,
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
@@ -1568,7 +1568,7 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
out = pipe(**inputs)[0]
|
||||
out = pipe("test", num_inference_steps=2, output_type="np")[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
@@ -1988,113 +1988,3 @@ class PeftLoraLoaderMixinTests:
|
||||
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results as set_adapters().",
|
||||
)
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_B_bias(self):
|
||||
# Currently, this test is only relevant for Flux Control LoRA as we are not
|
||||
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# keep track of the bias values of the base layers to perform checks later.
|
||||
bias_values = {}
|
||||
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
for name, module in denoiser.named_modules():
|
||||
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
if module.bias is not None:
|
||||
bias_values[name] = module.bias.data.clone()
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
denoiser_lora_config.lora_bias = False
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.delete_adapters("adapter-1")
|
||||
|
||||
denoiser_lora_config.lora_bias = True
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_correct_lora_configs_with_different_ranks(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.delete_adapters("adapter-1")
|
||||
else:
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
for name, _ in denoiser.named_modules():
|
||||
if "to_k" in name and "attn" in name and "lora" not in name:
|
||||
module_name_to_rank_update = name.replace(".base_layer.", ".")
|
||||
break
|
||||
|
||||
# change the rank_pattern
|
||||
updated_rank = denoiser_lora_config.r * 2
|
||||
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
|
||||
|
||||
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
|
||||
|
||||
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.delete_adapters("adapter-1")
|
||||
else:
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
# similarly change the alpha_pattern
|
||||
updated_alpha = denoiser_lora_config.lora_alpha * 2
|
||||
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
|
||||
)
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
|
||||
)
|
||||
|
||||
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTXVideo
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"LTXVideoEncoder3d",
|
||||
"LTXVideoDecoder3d",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoMidBlock3d",
|
||||
"LTXVideoUpBlock3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTXVideo
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (16, 32, 64),
|
||||
"layers_per_block": (1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
timestep = torch.tensor([0.05] * batch_size, device=torch_device)
|
||||
|
||||
return {"sample": image, "temb": timestep}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"LTXVideoEncoder3d",
|
||||
"LTXVideoDecoder3d",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoMidBlock3d",
|
||||
"LTXVideoUpBlock3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
@@ -2,12 +2,10 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
class AttnAddedKVProcessorTests(unittest.TestCase):
|
||||
@@ -81,15 +79,6 @@ class AttnAddedKVProcessorTests(unittest.TestCase):
|
||||
|
||||
|
||||
class DeprecatedAttentionBlockTests(unittest.TestCase):
|
||||
@pytest.fixture(scope="session")
|
||||
def is_dist_enabled(pytestconfig):
|
||||
return pytestconfig.getoption("dist") == "loadfile"
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
|
||||
reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
|
||||
strict=True,
|
||||
)
|
||||
def test_conversion_when_using_device_map(self):
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
||||
|
||||
@@ -22,14 +22,12 @@ import traceback
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests_mock
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
|
||||
from accelerate.utils import compute_module_sizes
|
||||
from huggingface_hub import ModelCard, delete_repo, snapshot_download
|
||||
from huggingface_hub.utils import is_jinja_available
|
||||
from parameterized import parameterized
|
||||
@@ -115,72 +113,6 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
|
||||
out_queue.join()
|
||||
|
||||
|
||||
def named_persistent_module_tensors(
|
||||
module: nn.Module,
|
||||
recurse: bool = False,
|
||||
):
|
||||
"""
|
||||
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module we want the tensors on.
|
||||
recurse (`bool`, *optional`, defaults to `False`):
|
||||
Whether or not to go look in every submodule or just return the direct parameters and buffers.
|
||||
"""
|
||||
yield from module.named_parameters(recurse=recurse)
|
||||
|
||||
for named_buffer in module.named_buffers(recurse=recurse):
|
||||
name, _ = named_buffer
|
||||
# Get parent by splitting on dots and traversing the model
|
||||
parent = module
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
for part in parent_name.split("."):
|
||||
parent = getattr(parent, part)
|
||||
name = name.split(".")[-1]
|
||||
if name not in parent._non_persistent_buffers_set:
|
||||
yield named_buffer
|
||||
|
||||
|
||||
def compute_module_persistent_sizes(
|
||||
model: nn.Module,
|
||||
dtype: Optional[Union[str, torch.device]] = None,
|
||||
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = _get_proper_dtype(dtype)
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
if special_dtypes is not None:
|
||||
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
module_sizes = defaultdict(int)
|
||||
|
||||
module_list = []
|
||||
|
||||
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||
|
||||
for name, tensor in module_list:
|
||||
if special_dtypes is not None and name in special_dtypes:
|
||||
size = tensor.numel() * special_dtypes_size[name]
|
||||
elif dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||
# so use their original size here
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
|
||||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
@@ -1080,7 +1012,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works.
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -1110,7 +1042,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
@@ -1144,7 +1076,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
@@ -1172,7 +1104,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works.
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -1200,7 +1132,7 @@ class ModelTesterMixin:
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
@@ -1232,7 +1164,7 @@ class ModelTesterMixin:
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
variant = "fp16"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -1272,7 +1204,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
@@ -1301,7 +1233,7 @@ class ModelTesterMixin:
|
||||
config, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
variant = "fp16"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
@@ -30,8 +30,6 @@ class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = MochiTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
# Overriding it because of the transformer size.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import SanaTransformer2DModel
|
||||
@@ -81,27 +80,3 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"SanaTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cuda",
|
||||
reason="Test currently fails.",
|
||||
strict=True,
|
||||
)
|
||||
def test_cpu_offload(self):
|
||||
return super().test_cpu_offload()
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cuda",
|
||||
reason="Test currently fails.",
|
||||
strict=True,
|
||||
)
|
||||
def test_disk_offload_with_safetensors(self):
|
||||
return super().test_disk_offload_with_safetensors()
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cuda",
|
||||
reason="Test currently fails.",
|
||||
strict=True,
|
||||
)
|
||||
def test_disk_offload_without_safetensors(self):
|
||||
return super().test_disk_offload_without_safetensors()
|
||||
|
||||
@@ -63,19 +63,10 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTXVideo(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=8,
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
decoder_block_out_channels=(8, 8, 8, 8),
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
decoder_layers_per_block=(1, 1, 1, 1, 1),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_inject_noise=(False, False, False, False, False),
|
||||
upsample_residual=(False, False, False, False),
|
||||
upsample_factor=(1, 1, 1, 1),
|
||||
timestep_conditioning=False,
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
|
||||
@@ -68,19 +68,10 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTXVideo(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=8,
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
decoder_block_out_channels=(8, 8, 8, 8),
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
decoder_layers_per_block=(1, 1, 1, 1, 1),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_inject_noise=(False, False, False, False, False),
|
||||
upsample_residual=(False, False, False, False),
|
||||
upsample_factor=(1, 1, 1, 1),
|
||||
timestep_conditioning=False,
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
|
||||
from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
|
||||
|
||||
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -101,7 +101,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = Gemma2Config(
|
||||
head_dim=16,
|
||||
hidden_size=8,
|
||||
hidden_size=32,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=64,
|
||||
max_position_embeddings=8192,
|
||||
@@ -112,7 +112,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
vocab_size=8,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
text_encoder = Gemma2Model(text_encoder_config)
|
||||
text_encoder = Gemma2ForCausalLM(text_encoder_config)
|
||||
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils.testing_utils import (
|
||||
nightly,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torchao_version_greater_or_equal,
|
||||
require_torchao_version_greater,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -74,13 +74,13 @@ if is_torch_available():
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
@require_torchao_version_greater("0.6.0")
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
@@ -125,7 +125,7 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
@require_torchao_version_greater("0.6.0")
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
@@ -139,13 +139,11 @@ class TorchAoTest(unittest.TestCase):
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
@@ -214,7 +212,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device)
|
||||
pipe.to(device=torch_device, dtype=torch.bfloat16)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0]
|
||||
@@ -278,6 +276,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
self.assertEqual(weight.quant_min, 0)
|
||||
self.assertEqual(weight.quant_max, 15)
|
||||
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
|
||||
|
||||
def test_device_map(self):
|
||||
"""
|
||||
@@ -342,22 +341,6 @@ class TorchAoTest(unittest.TestCase):
|
||||
|
||||
def test_modules_to_not_convert(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
|
||||
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
|
||||
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
|
||||
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
|
||||
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
|
||||
|
||||
quantized_layer = quantized_model_with_not_convert.proj_out
|
||||
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
|
||||
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -365,10 +348,14 @@ class TorchAoTest(unittest.TestCase):
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
|
||||
size_quantized = get_model_size_in_bytes(quantized_model)
|
||||
unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
|
||||
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
|
||||
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
|
||||
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
|
||||
|
||||
self.assertTrue(size_quantized < size_quantized_with_not_convert)
|
||||
quantized_layer = quantized_model.proj_out
|
||||
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
|
||||
self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
|
||||
|
||||
def test_training(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
@@ -419,6 +406,23 @@ class TorchAoTest(unittest.TestCase):
|
||||
# Note: Seems to require higher tolerance
|
||||
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
|
||||
|
||||
@staticmethod
|
||||
def _get_memory_footprint(module):
|
||||
quantized_param_memory = 0.0
|
||||
unquantized_param_memory = 0.0
|
||||
|
||||
for param in module.parameters():
|
||||
if param.__class__.__name__ == "AffineQuantizedTensor":
|
||||
data, scale, zero_point = param.layout_tensor.get_plain()
|
||||
quantized_param_memory += data.numel() + data.element_size()
|
||||
quantized_param_memory += scale.numel() + scale.element_size()
|
||||
quantized_param_memory += zero_point.numel() + zero_point.element_size()
|
||||
else:
|
||||
unquantized_param_memory += param.data.numel() * param.data.element_size()
|
||||
|
||||
total_memory = quantized_param_memory + unquantized_param_memory
|
||||
return total_memory, quantized_param_memory, unquantized_param_memory
|
||||
|
||||
def test_memory_footprint(self):
|
||||
r"""
|
||||
A simple test to check if the model conversion has been done correctly by checking on the
|
||||
@@ -429,37 +433,41 @@ class TorchAoTest(unittest.TestCase):
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
|
||||
transformer_bf16 = self.get_dummy_components(None)["transformer"]
|
||||
|
||||
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
|
||||
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
|
||||
total_int8wo = get_model_size_in_bytes(transformer_int8wo)
|
||||
total_bf16 = get_model_size_in_bytes(transformer_bf16)
|
||||
total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
|
||||
total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
|
||||
transformer_int4wo_gs32
|
||||
)
|
||||
total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
|
||||
total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
|
||||
|
||||
# Latter has smaller group size, so more groups -> more scales and zero points
|
||||
self.assertTrue(total_int4wo < total_int4wo_gs32)
|
||||
self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
|
||||
# int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
|
||||
self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
|
||||
# int4 with default group size quantized very few linear layers compared to a smaller group size of 32
|
||||
self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
|
||||
# int8 quantizes more layers compare to int4 with default group size
|
||||
self.assertTrue(total_int8wo < total_int4wo)
|
||||
# int4wo does not quantize too many layers because of default group size, but for the layers it does
|
||||
# there is additional overhead of scales and zero points
|
||||
self.assertTrue(total_bf16 < total_int4wo)
|
||||
self.assertTrue(quantized_int8wo < quantized_int4wo)
|
||||
|
||||
def test_wrong_config(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.get_dummy_components(TorchAoConfig("int42"))
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
# This class is not to be run as a test by itself. See the tests that follow this class
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
@require_torchao_version_greater("0.6.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||
quant_method, quant_method_kwargs = None, None
|
||||
device = "cuda"
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
|
||||
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
|
||||
def get_dummy_model(self, device=None):
|
||||
quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
self.model_name,
|
||||
subfolder="transformer",
|
||||
@@ -495,15 +503,15 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
|
||||
def test_original_model_expected_slice(self):
|
||||
quantized_model = self.get_dummy_model(torch_device)
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
|
||||
def check_serialization_expected_slice(self, expected_slice):
|
||||
quantized_model = self.get_dummy_model(self.device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
@@ -522,39 +530,42 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_int_a8w8_cuda(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = "cuda"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
def test_serialization_expected_slice(self):
|
||||
self.check_serialization_expected_slice(self.serialized_expected_slice)
|
||||
|
||||
def test_int_a16w8_cuda(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = "cuda"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
def test_int_a8w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cuda"
|
||||
|
||||
def test_int_a16w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cuda"
|
||||
|
||||
|
||||
class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cpu"
|
||||
|
||||
|
||||
class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cpu"
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
@require_torchao_version_greater("0.6.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoTests(unittest.TestCase):
|
||||
@@ -570,13 +581,11 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
@@ -608,7 +617,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
|
||||
def _test_quant_type(self, quantization_config, expected_slice):
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user