Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e8aacda762 | |||
| 12184f4015 | |||
| 6e1d2da194 | |||
| 11b1151840 | |||
| cd4d0d8ffb | |||
| 4b557132ce | |||
| 851dfa30ae | |||
| ea1ba0ba53 | |||
| 9d27df8071 | |||
| 055d95543a | |||
| 71cc2013fe | |||
| c34fc34563 | |||
| 5fcee4a447 | |||
| 76e2727b5c | |||
| 02c777c065 | |||
| 6a970a45c5 | |||
| ffc0eaab6d | |||
| 3c2e2aa8a9 | |||
| b58868e6f4 | |||
| da21d590b5 | |||
| 7c2f0afb1c | |||
| f615f00f58 | |||
| 6aaa0518e3 |
@@ -359,6 +359,8 @@ jobs:
|
||||
test_location: "bnb"
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Test installing diffusers and importing
|
||||
run: |
|
||||
pip install diffusers && pip uninstall diffusers -y
|
||||
pip install -i https://testpypi.python.org/pypi diffusers
|
||||
pip install -i https://test.pypi.org/simple/ diffusers
|
||||
python -c "from diffusers import __version__; print(__version__)"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
|
||||
|
||||
@@ -429,7 +429,7 @@
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTX
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanVideo
|
||||
|
||||
vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
|
||||
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanVideo
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTXVideo
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HunyuanVideoTransformer3DModel
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import LTXVideoTransformer3DModel
|
||||
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## LTXVideoTransformer3DModel
|
||||
|
||||
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import SanaTransformer2DModel
|
||||
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## SanaTransformer2DModel
|
||||
|
||||
@@ -29,7 +29,7 @@ Recommendations for inference:
|
||||
- Transformer should be in `torch.bfloat16`.
|
||||
- VAE should be in `torch.float16`.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
|
||||
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
|
||||
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
|
||||
|
||||
## HunyuanVideoPipeline
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# LTX
|
||||
# LTX Video
|
||||
|
||||
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
|
||||
|
||||
@@ -22,14 +22,24 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
Available models:
|
||||
|
||||
| Model name | Recommended dtype |
|
||||
|:-------------:|:-----------------:|
|
||||
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
|
||||
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
|
||||
|
||||
## Loading Single Files
|
||||
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
|
||||
|
||||
# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
|
||||
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
single_file_url, torch_dtype=torch.bfloat16
|
||||
@@ -99,6 +109,34 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)
|
||||
|
||||
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
|
||||
|
||||
<!-- TODO(aryan): Update this when official weights are supported -->
|
||||
|
||||
Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=768,
|
||||
height=512,
|
||||
num_frames=161,
|
||||
decode_timestep=0.03,
|
||||
decode_noise_scale=0.025,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
|
||||
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
@@ -32,9 +32,9 @@ Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
|
||||
|
||||
@@ -25,9 +25,10 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
|
||||
The example below only quantizes the weights to int8.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
model_id = "black-forest-labs/Flux.1-Dev"
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
@@ -44,8 +45,14 @@ pipe = FluxPipeline.from_pretrained(
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Without quantization: ~31.447 GB
|
||||
# With quantization: ~20.40 GB
|
||||
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
@@ -86,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use
|
||||
|
||||
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
|
||||
|
||||
## Serializing and Deserializing quantized models
|
||||
|
||||
To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
|
||||
```
|
||||
|
||||
To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel
|
||||
|
||||
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
# Serialize the model
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=TorchAoConfig("uint4wo"),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
|
||||
# ...
|
||||
|
||||
# Load the model
|
||||
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
|
||||
with init_empty_weights():
|
||||
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
check_min_version("0.32.0")
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRASANA(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
|
||||
transformer_layer_type = "transformer_blocks.0.attn1.to_k"
|
||||
|
||||
def test_dreambooth_lora_sana(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# `self.transformer_layer_type` should be in the state dict.
|
||||
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 166
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
resume_run_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -943,7 +943,7 @@ def main(args):
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
|
||||
)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
text_encoder = Gemma2Model.from_pretrained(
|
||||
@@ -964,15 +964,6 @@ def main(args):
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = SanaPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=None,
|
||||
transformer=None,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
@@ -993,6 +984,15 @@ def main(args):
|
||||
# because Gemma2 is particularly suited for bfloat16.
|
||||
text_encoder.to(dtype=torch.bfloat16)
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = SanaPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=None,
|
||||
transformer=None,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
transformer.enable_gradient_checkpointing()
|
||||
|
||||
@@ -1182,6 +1182,7 @@ def main(args):
|
||||
)
|
||||
if args.offload:
|
||||
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
|
||||
prompt_embeds = prompt_embeds.to(transformer.dtype)
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
@@ -1216,7 +1217,7 @@ def main(args):
|
||||
vae_config_scaling_factor = vae.config.scaling_factor
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
vae = vae.to("cuda")
|
||||
vae = vae.to(accelerator.device)
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
with torch.no_grad():
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
@@ -21,7 +23,9 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"vae": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
# decoder
|
||||
@@ -54,10 +58,31 @@ VAE_KEYS_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
@@ -80,13 +105,16 @@ def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = ""
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -97,16 +125,21 @@ def convert_transformer(
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True)
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
PREFIX_KEY = "vae."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTXVideo(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True)
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.0":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"timestep_conditioning": False,
|
||||
}
|
||||
elif version == "0.9.1":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -139,6 +222,9 @@ def get_args():
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -161,6 +247,7 @@ if __name__ == "__main__":
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
output_path = Path(args.output_path)
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
@@ -169,13 +256,14 @@ if __name__ == "__main__":
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
|
||||
config = get_vae_config(args.version)
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
|
||||
@@ -88,13 +88,18 @@ def main(args):
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 3.0
|
||||
|
||||
# model config
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
layer_num = 20
|
||||
elif args.model_type == "SanaMS_600M_P1_D28":
|
||||
layer_num = 28
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
@@ -176,6 +181,7 @@ def main(args):
|
||||
patch_size=1,
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
interpolation_scale=interpolation_scale[args.image_size],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
|
||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.32.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.32.0.dev0"
|
||||
__version__ = "0.32.1"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from .single_file_utils import (
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
@@ -101,6 +102,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"HunyuanVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -220,6 +225,7 @@ class FromOriginalModelMixin:
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
@@ -297,7 +303,7 @@ class FromOriginalModelMixin:
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
revision=config_revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
|
||||
@@ -108,6 +108,7 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
|
||||
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
|
||||
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -156,12 +157,14 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
|
||||
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
|
||||
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -603,7 +606,10 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
model_type = "ltx-video"
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
model_type = "ltx-video-0.9.1"
|
||||
else:
|
||||
model_type = "ltx-video"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
|
||||
encoder_key = "encoder.project_in.conv.conv.bias"
|
||||
@@ -624,6 +630,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
|
||||
model_type = "mochi-1-preview"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
|
||||
model_type = "hunyuan-video"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -2333,12 +2342,32 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
@@ -2522,3 +2551,133 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict):
|
||||
hidden_size = 3072
|
||||
|
||||
if "linear1.weight" in key:
|
||||
linear1_weight = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
||||
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
|
||||
state_dict[f"{new_key}.attn.to_q.weight"] = q
|
||||
state_dict[f"{new_key}.attn.to_k.weight"] = k
|
||||
state_dict[f"{new_key}.attn.to_v.weight"] = v
|
||||
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
|
||||
|
||||
elif "linear1.bias" in key:
|
||||
linear1_bias = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
|
||||
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
|
||||
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
|
||||
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
|
||||
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
|
||||
|
||||
else:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
||||
new_key = new_key.replace("linear2", "proj_out")
|
||||
new_key = new_key.replace("q_norm", "attn.norm_q")
|
||||
new_key = new_key.replace("k_norm", "attn.norm_k")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
def update_state_dict_(state_dict, old_key, new_key):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(checkpoint, key, new_key)
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, checkpoint)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@@ -177,3 +177,5 @@ class FluxTransformer2DLoadersMixin:
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
@@ -22,13 +22,14 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class LTXCausalConv3d(nn.Module):
|
||||
class LTXVideoCausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -79,9 +80,9 @@ class LTXCausalConv3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXResnetBlock3d(nn.Module):
|
||||
class LTXVideoResnetBlock3d(nn.Module):
|
||||
r"""
|
||||
A 3D ResNet block used in the LTX model.
|
||||
A 3D ResNet block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -109,7 +110,9 @@ class LTXResnetBlock3d(nn.Module):
|
||||
elementwise_affine: bool = False,
|
||||
non_linearity: str = "swish",
|
||||
is_causal: bool = True,
|
||||
):
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
@@ -117,13 +120,13 @@ class LTXResnetBlock3d(nn.Module):
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
|
||||
self.conv1 = LTXCausalConv3d(
|
||||
self.conv1 = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = LTXCausalConv3d(
|
||||
self.conv2 = LTXVideoCausalConv3d(
|
||||
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
|
||||
)
|
||||
|
||||
@@ -131,22 +134,58 @@ class LTXResnetBlock3d(nn.Module):
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
|
||||
self.conv_shortcut = LTXCausalConv3d(
|
||||
self.conv_shortcut = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
self.per_channel_scale1 = None
|
||||
self.per_channel_scale2 = None
|
||||
if inject_noise:
|
||||
self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
|
||||
self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
|
||||
|
||||
self.scale_shift_table = None
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
|
||||
|
||||
def forward(
|
||||
self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs
|
||||
|
||||
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.scale_shift_table is not None:
|
||||
temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
|
||||
shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
|
||||
hidden_states = hidden_states * (1 + scale_1) + shift_1
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if self.per_channel_scale1 is not None:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
spatial_noise = torch.randn(
|
||||
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)[None]
|
||||
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
|
||||
|
||||
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.scale_shift_table is not None:
|
||||
hidden_states = hidden_states * (1 + scale_2) + shift_2
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.per_channel_scale2 is not None:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
spatial_noise = torch.randn(
|
||||
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)[None]
|
||||
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
|
||||
|
||||
if self.norm3 is not None:
|
||||
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
@@ -157,20 +196,24 @@ class LTXResnetBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXUpsampler3d(nn.Module):
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
||||
self.residual = residual
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
out_channels = in_channels * stride[0] * stride[1] * stride[2]
|
||||
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
|
||||
self.conv = LTXCausalConv3d(
|
||||
self.conv = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
@@ -181,6 +224,15 @@ class LTXUpsampler3d(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.residual:
|
||||
residual = hidden_states.reshape(
|
||||
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
|
||||
)
|
||||
residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
|
||||
residual = residual.repeat(1, repeats, 1, 1, 1)
|
||||
residual = residual[:, :, self.stride[0] - 1 :]
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
|
||||
@@ -188,12 +240,15 @@ class LTXUpsampler3d(nn.Module):
|
||||
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
|
||||
|
||||
if self.residual:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXDownBlock3D(nn.Module):
|
||||
class LTXVideoDownBlock3D(nn.Module):
|
||||
r"""
|
||||
Down block used in the LTX model.
|
||||
Down block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -235,7 +290,7 @@ class LTXDownBlock3D(nn.Module):
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXResnetBlock3d(
|
||||
LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
@@ -250,7 +305,7 @@ class LTXDownBlock3D(nn.Module):
|
||||
if spatio_temporal_scale:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXCausalConv3d(
|
||||
LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
@@ -262,7 +317,7 @@ class LTXDownBlock3D(nn.Module):
|
||||
|
||||
self.conv_out = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_out = LTXResnetBlock3d(
|
||||
self.conv_out = LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
@@ -273,7 +328,12 @@ class LTXDownBlock3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXDownBlock3D` class."""
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
@@ -285,24 +345,26 @@ class LTXDownBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
if self.conv_out is not None:
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states, temb, generator)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
|
||||
class LTXMidBlock3d(nn.Module):
|
||||
class LTXVideoMidBlock3d(nn.Module):
|
||||
r"""
|
||||
A middle block used in the LTX model.
|
||||
A middle block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -329,28 +391,51 @@ class LTXMidBlock3d(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.time_embedder = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXResnetBlock3d(
|
||||
LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXMidBlock3D` class."""
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -360,16 +445,18 @@ class LTXMidBlock3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXUpBlock3d(nn.Module):
|
||||
class LTXVideoUpBlock3d(nn.Module):
|
||||
r"""
|
||||
Up block used in the LTX model.
|
||||
Up block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -403,45 +490,82 @@ class LTXUpBlock3d(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.time_embedder = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
|
||||
|
||||
self.conv_in = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_in = LTXResnetBlock3d(
|
||||
self.conv_in = LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
|
||||
self.upsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoUpsampler3d(
|
||||
out_channels * upscale_factor,
|
||||
stride=(2, 2, 2),
|
||||
is_causal=is_causal,
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXResnetBlock3d(
|
||||
LTXVideoResnetBlock3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.conv_in is not None:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
hidden_states = self.conv_in(hidden_states, temb, generator)
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -456,16 +580,18 @@ class LTXUpBlock3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXEncoder3d(nn.Module):
|
||||
class LTXVideoEncoder3d(nn.Module):
|
||||
r"""
|
||||
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
|
||||
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
|
||||
representation.
|
||||
|
||||
Args:
|
||||
@@ -509,7 +635,7 @@ class LTXEncoder3d(nn.Module):
|
||||
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXCausalConv3d(
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
@@ -524,7 +650,7 @@ class LTXEncoder3d(nn.Module):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
|
||||
down_block = LTXDownBlock3D(
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
@@ -536,7 +662,7 @@ class LTXEncoder3d(nn.Module):
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid block
|
||||
self.mid_block = LTXMidBlock3d(
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
in_channels=output_channel,
|
||||
num_layers=layers_per_block[-1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
@@ -546,14 +672,14 @@ class LTXEncoder3d(nn.Module):
|
||||
# out
|
||||
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LTXCausalConv3d(
|
||||
self.conv_out = LTXVideoCausalConv3d(
|
||||
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""The forward method of the `LTXEncoder3D` class."""
|
||||
r"""The forward method of the `LTXVideoEncoder3d` class."""
|
||||
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
@@ -599,9 +725,10 @@ class LTXEncoder3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXDecoder3d(nn.Module):
|
||||
class LTXVideoDecoder3d(nn.Module):
|
||||
r"""
|
||||
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
||||
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
|
||||
sample.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 128):
|
||||
@@ -622,6 +749,8 @@ class LTXDecoder3d(nn.Module):
|
||||
Epsilon value for ResNet normalization layers.
|
||||
is_causal (`bool`, defaults to `False`):
|
||||
Whether this layer behaves causally (future frames depend only on past frames) or not.
|
||||
timestep_conditioning (`bool`, defaults to `False`):
|
||||
Whether to condition the model on timesteps.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -635,6 +764,10 @@ class LTXDecoder3d(nn.Module):
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: Tuple[bool, ...] = (False, False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -645,30 +778,42 @@ class LTXDecoder3d(nn.Module):
|
||||
block_out_channels = tuple(reversed(block_out_channels))
|
||||
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
|
||||
layers_per_block = tuple(reversed(layers_per_block))
|
||||
inject_noise = tuple(reversed(inject_noise))
|
||||
upsample_residual = tuple(reversed(upsample_residual))
|
||||
upsample_factor = tuple(reversed(upsample_factor))
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXCausalConv3d(
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.mid_block = LTXMidBlock3d(
|
||||
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
in_channels=output_channel,
|
||||
num_layers=layers_per_block[0],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise[0],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
|
||||
# up blocks
|
||||
num_block_out_channels = len(block_out_channels)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i in range(num_block_out_channels):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
input_channel = output_channel // upsample_factor[i]
|
||||
output_channel = block_out_channels[i] // upsample_factor[i]
|
||||
|
||||
up_block = LTXUpBlock3d(
|
||||
up_block = LTXVideoUpBlock3d(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i + 1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise[i + 1],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
upsample_residual=upsample_residual[i],
|
||||
upscale_factor=upsample_factor[i],
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
@@ -676,13 +821,20 @@ class LTXDecoder3d(nn.Module):
|
||||
# out
|
||||
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LTXCausalConv3d(
|
||||
self.conv_out = LTXVideoCausalConv3d(
|
||||
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
# timestep embedding
|
||||
self.time_embedder = None
|
||||
self.scale_shift_table = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
@@ -693,17 +845,33 @@ class LTXDecoder3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb
|
||||
)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
hidden_states = self.mid_block(hidden_states, temb)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
hidden_states = up_block(hidden_states, temb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
|
||||
temb = temb + self.scale_shift_table[None, ..., None, None, None]
|
||||
shift, scale = temb.unbind(dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
@@ -766,8 +934,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -777,7 +952,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = LTXEncoder3d(
|
||||
self.encoder = LTXVideoEncoder3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
@@ -788,16 +963,20 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
is_causal=encoder_causal,
|
||||
)
|
||||
self.decoder = LTXDecoder3d(
|
||||
self.decoder = LTXVideoDecoder3d(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
spatio_temporal_scaling=spatio_temporal_scaling,
|
||||
layers_per_block=layers_per_block,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
is_causal=decoder_causal,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
inject_noise=decoder_inject_noise,
|
||||
upsample_residual=upsample_residual,
|
||||
upsample_factor=upsample_factor,
|
||||
)
|
||||
|
||||
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
|
||||
@@ -837,7 +1016,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_sample_stride_width = 448
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
|
||||
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(
|
||||
@@ -936,13 +1115,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
return self.tiled_decode(z, temb, return_dict=return_dict)
|
||||
|
||||
if self.use_framewise_decoding:
|
||||
# TODO(aryan): requires investigation
|
||||
@@ -952,7 +1133,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
dec = self.decoder(z)
|
||||
dec = self.decoder(z, temb)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
@@ -960,7 +1141,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -975,10 +1158,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
if temb is not None:
|
||||
decoded_slices = [
|
||||
self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1))
|
||||
]
|
||||
else:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
decoded = self._decode(z, temb).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
@@ -1060,7 +1248,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1101,7 +1291,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
|
||||
time = self.decoder(
|
||||
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
|
||||
)
|
||||
|
||||
row.append(time)
|
||||
rows.append(row)
|
||||
@@ -1129,6 +1321,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
@@ -1139,7 +1332,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
dec = self.decode(z, temb)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return dec
|
||||
|
||||
@@ -748,10 +748,10 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
pos_embedding = self._get_positional_embeddings(
|
||||
height, width, pre_time_compression_frames, device=embeds.device
|
||||
)
|
||||
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
|
||||
else:
|
||||
pos_embedding = self.pos_embedding
|
||||
|
||||
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
|
||||
embeds = embeds + pos_embedding
|
||||
|
||||
return embeds
|
||||
|
||||
@@ -228,7 +228,7 @@ def load_model_dict_into_meta(
|
||||
else:
|
||||
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
||||
raise ValueError(
|
||||
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
)
|
||||
|
||||
if is_quantized and (
|
||||
|
||||
@@ -718,10 +718,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
hf_quantizer = None
|
||||
|
||||
if hf_quantizer is not None:
|
||||
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
|
||||
if is_bnb_quantization_method and device_map is not None:
|
||||
if device_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
|
||||
"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
|
||||
)
|
||||
|
||||
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
|
||||
@@ -820,7 +819,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
revision=revision,
|
||||
subfolder=subfolder or "",
|
||||
)
|
||||
if hf_quantizer is not None and is_bnb_quantization_method:
|
||||
# TODO: https://github.com/huggingface/diffusers/issues/10013
|
||||
if hf_quantizer is not None:
|
||||
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
|
||||
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
|
||||
is_sharded = False
|
||||
|
||||
@@ -242,6 +242,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
patch_size: int = 1,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -249,14 +250,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch Embedding
|
||||
interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
|
||||
self.patch_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=None,
|
||||
pos_embed_type=None,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
|
||||
@@ -18,6 +18,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
@@ -500,7 +502,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LTXAttentionProcessor2_0:
|
||||
class LTXVideoAttentionProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
|
||||
@@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
@@ -92,7 +92,7 @@ class LTXAttentionProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXRotaryPosEmbed(nn.Module):
|
||||
class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -164,7 +164,7 @@ class LTXRotaryPosEmbed(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class LTXTransformerBlock(nn.Module):
|
||||
class LTXVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
|
||||
|
||||
@@ -208,7 +208,7 @@ class LTXTransformerBlock(nn.Module):
|
||||
cross_attention_dim=None,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
processor=LTXVideoAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -221,7 +221,7 @@ class LTXTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
processor=LTXVideoAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
@@ -327,7 +327,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.rope = LTXRotaryPosEmbed(
|
||||
self.rope = LTXVideoRotaryPosEmbed(
|
||||
dim=inner_dim,
|
||||
base_num_frames=20,
|
||||
base_height=2048,
|
||||
@@ -339,7 +339,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LTXTransformerBlock(
|
||||
LTXVideoTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
|
||||
@@ -39,7 +39,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> model_id = "tencent/HunyuanVideo"
|
||||
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
@@ -193,15 +193,15 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
r"""
|
||||
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
|
||||
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
|
||||
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
|
||||
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
@@ -411,7 +411,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
@@ -419,8 +419,8 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
@@ -652,7 +652,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
@@ -660,8 +660,8 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
|
||||
@@ -511,6 +511,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -563,6 +565,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -753,7 +759,25 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
|
||||
@@ -571,6 +571,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -625,6 +627,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -849,7 +855,25 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
|
||||
@@ -59,13 +59,13 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import SanaPAGPipeline
|
||||
|
||||
>>> pipe = SanaPAGPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
||||
... pag_applied_layers=["transformer_blocks.8"],
|
||||
... torch_dtype=torch.float32,
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> pipe.text_encoder.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.float16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
|
||||
|
||||
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
|
||||
>>> image[0].save("output.png")
|
||||
|
||||
@@ -62,11 +62,11 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import SanaPipeline
|
||||
|
||||
>>> pipe = SanaPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> pipe.text_encoder.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.float16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
|
||||
|
||||
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
|
||||
>>> image[0].save("output.png")
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
|
||||
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
@@ -35,21 +35,28 @@ if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
# At the moment, only int8 is supported for integer quantization dtypes.
|
||||
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
|
||||
# to support more quantization methods, such as intx_weight_only.
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint1,
|
||||
torch.uint2,
|
||||
torch.uint3,
|
||||
torch.uint4,
|
||||
torch.uint5,
|
||||
torch.uint6,
|
||||
torch.uint7,
|
||||
)
|
||||
if is_torch_version(">=", "2.5"):
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
# At the moment, only int8 is supported for integer quantization dtypes.
|
||||
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
|
||||
# to support more quantization methods, such as intx_weight_only.
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint1,
|
||||
torch.uint2,
|
||||
torch.uint3,
|
||||
torch.uint4,
|
||||
torch.uint5,
|
||||
torch.uint6,
|
||||
torch.uint7,
|
||||
)
|
||||
else:
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
)
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import quantize_
|
||||
@@ -93,6 +100,11 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
raise ImportError(
|
||||
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
|
||||
)
|
||||
torchao_version = version.parse(importlib.metadata.version("torch"))
|
||||
if torchao_version < version.parse("0.7.0"):
|
||||
raise RuntimeError(
|
||||
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
)
|
||||
|
||||
self.offload = False
|
||||
|
||||
@@ -120,7 +132,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
def update_torch_dtype(self, torch_dtype):
|
||||
quant_type = self.quantization_config.quant_type
|
||||
|
||||
if quant_type.startswith("int"):
|
||||
if quant_type.startswith("int") or quant_type.startswith("uint"):
|
||||
if torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning(
|
||||
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
|
||||
|
||||
@@ -490,11 +490,11 @@ def require_gguf_version_greater_or_equal(gguf_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_torchao_version_greater(torchao_version):
|
||||
def require_torchao_version_greater_or_equal(torchao_version):
|
||||
def decorator(test_case):
|
||||
correct_torchao_version = is_torchao_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("torchao")).base_version
|
||||
) > version.parse(torchao_version)
|
||||
) >= version.parse(torchao_version)
|
||||
return unittest.skipUnless(
|
||||
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
|
||||
)(test_case)
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -29,16 +27,13 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -123,41 +118,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@skip_mps
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -331,7 +330,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -340,85 +340,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_B_bias(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
# Testing opposite direction where the LoRA params are zero-padded.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# keep track of the bias values of the base layers to perform checks later.
|
||||
bias_values = {}
|
||||
for name, module in pipe.transformer.named_modules():
|
||||
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
if module.bias is not None:
|
||||
bias_values[name] = module.bias.data.clone()
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
denoiser_lora_config.lora_bias = False
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.delete_adapters("adapter-1")
|
||||
|
||||
denoiser_lora_config.lora_bias = True
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# for now this is flux control lora specific but can be generalized later and added to ./utils.py
|
||||
def test_correct_lora_configs_with_different_ranks(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
# change the rank_pattern
|
||||
updated_rank = denoiser_lora_config.r * 2
|
||||
denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
|
||||
"single_transformer_blocks.0.attn.to_k": updated_rank
|
||||
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
|
||||
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
# similarly change the alpha_pattern
|
||||
updated_alpha = denoiser_lora_config.lora_alpha * 2
|
||||
denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
|
||||
"single_transformer_blocks.0.attn.to_k": updated_alpha
|
||||
}
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
|
||||
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_lora_expanding_shape_with_normal_lora(self):
|
||||
# This test checks if it works when a lora with expanded shapes (like control loras) but
|
||||
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
|
||||
# tested with it.
|
||||
def test_normal_lora_with_expanded_lora_raises_error(self):
|
||||
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
|
||||
# load shape expanded LoRA (such as Control LoRA).
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
@@ -28,16 +26,14 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -144,46 +140,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
prompt=inputs["prompt"],
|
||||
height=inputs["height"],
|
||||
width=inputs["width"],
|
||||
num_frames=inputs["num_frames"],
|
||||
num_inference_steps=inputs["num_inference_steps"],
|
||||
max_sequence_length=inputs["max_sequence_length"],
|
||||
output_type="np",
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -26,18 +24,12 @@ from diffusers import (
|
||||
LTXPipeline,
|
||||
LTXVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -60,10 +52,19 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
transformer_cls = LTXVideoTransformer3DModel
|
||||
vae_kwargs = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
@@ -107,41 +108,6 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@skip_mps
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -15,24 +15,20 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -103,40 +99,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import Gemma2ForCausalLM, GemmaTokenizer
|
||||
from transformers import Gemma2Model, GemmaTokenizer
|
||||
|
||||
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
|
||||
@@ -73,7 +73,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
vae_cls = AutoencoderDC
|
||||
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
|
||||
text_encoder_cls, text_encoder_id = Gemma2ForCausalLM, "hf-internal-testing/dummy-gemma-for-diffusers"
|
||||
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
@@ -105,34 +105,34 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@unittest.skip("Not supported in Sana.")
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
+112
-2
@@ -1528,7 +1528,7 @@ class PeftLoraLoaderMixinTests:
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
strict=False,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
@@ -1568,7 +1568,7 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
out = pipe("test", num_inference_steps=2, output_type="np")[0]
|
||||
out = pipe(**inputs)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
@@ -1988,3 +1988,113 @@ class PeftLoraLoaderMixinTests:
|
||||
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results as set_adapters().",
|
||||
)
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_B_bias(self):
|
||||
# Currently, this test is only relevant for Flux Control LoRA as we are not
|
||||
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# keep track of the bias values of the base layers to perform checks later.
|
||||
bias_values = {}
|
||||
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
for name, module in denoiser.named_modules():
|
||||
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
if module.bias is not None:
|
||||
bias_values[name] = module.bias.data.clone()
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
denoiser_lora_config.lora_bias = False
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.delete_adapters("adapter-1")
|
||||
|
||||
denoiser_lora_config.lora_bias = True
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_correct_lora_configs_with_different_ranks(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.delete_adapters("adapter-1")
|
||||
else:
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
for name, _ in denoiser.named_modules():
|
||||
if "to_k" in name and "attn" in name and "lora" not in name:
|
||||
module_name_to_rank_update = name.replace(".base_layer.", ".")
|
||||
break
|
||||
|
||||
# change the rank_pattern
|
||||
updated_rank = denoiser_lora_config.r * 2
|
||||
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
|
||||
|
||||
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
|
||||
|
||||
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.delete_adapters("adapter-1")
|
||||
else:
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
# similarly change the alpha_pattern
|
||||
updated_alpha = denoiser_lora_config.lora_alpha * 2
|
||||
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
|
||||
)
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
|
||||
)
|
||||
|
||||
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTXVideo
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"LTXVideoEncoder3d",
|
||||
"LTXVideoDecoder3d",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoMidBlock3d",
|
||||
"LTXVideoUpBlock3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTXVideo
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (16, 32, 64),
|
||||
"layers_per_block": (1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
timestep = torch.tensor([0.05] * batch_size, device=torch_device)
|
||||
|
||||
return {"sample": image, "temb": timestep}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"LTXVideoEncoder3d",
|
||||
"LTXVideoDecoder3d",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoMidBlock3d",
|
||||
"LTXVideoUpBlock3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
@@ -2,10 +2,12 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
class AttnAddedKVProcessorTests(unittest.TestCase):
|
||||
@@ -79,6 +81,15 @@ class AttnAddedKVProcessorTests(unittest.TestCase):
|
||||
|
||||
|
||||
class DeprecatedAttentionBlockTests(unittest.TestCase):
|
||||
@pytest.fixture(scope="session")
|
||||
def is_dist_enabled(pytestconfig):
|
||||
return pytestconfig.getoption("dist") == "loadfile"
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
|
||||
reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
|
||||
strict=True,
|
||||
)
|
||||
def test_conversion_when_using_device_map(self):
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
||||
|
||||
@@ -22,12 +22,14 @@ import traceback
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
from typing import Dict, List, Tuple
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import requests_mock
|
||||
import torch
|
||||
from accelerate.utils import compute_module_sizes
|
||||
import torch.nn as nn
|
||||
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
|
||||
from huggingface_hub import ModelCard, delete_repo, snapshot_download
|
||||
from huggingface_hub.utils import is_jinja_available
|
||||
from parameterized import parameterized
|
||||
@@ -113,6 +115,72 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
|
||||
out_queue.join()
|
||||
|
||||
|
||||
def named_persistent_module_tensors(
|
||||
module: nn.Module,
|
||||
recurse: bool = False,
|
||||
):
|
||||
"""
|
||||
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module we want the tensors on.
|
||||
recurse (`bool`, *optional`, defaults to `False`):
|
||||
Whether or not to go look in every submodule or just return the direct parameters and buffers.
|
||||
"""
|
||||
yield from module.named_parameters(recurse=recurse)
|
||||
|
||||
for named_buffer in module.named_buffers(recurse=recurse):
|
||||
name, _ = named_buffer
|
||||
# Get parent by splitting on dots and traversing the model
|
||||
parent = module
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
for part in parent_name.split("."):
|
||||
parent = getattr(parent, part)
|
||||
name = name.split(".")[-1]
|
||||
if name not in parent._non_persistent_buffers_set:
|
||||
yield named_buffer
|
||||
|
||||
|
||||
def compute_module_persistent_sizes(
|
||||
model: nn.Module,
|
||||
dtype: Optional[Union[str, torch.device]] = None,
|
||||
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = _get_proper_dtype(dtype)
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
if special_dtypes is not None:
|
||||
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
module_sizes = defaultdict(int)
|
||||
|
||||
module_list = []
|
||||
|
||||
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||
|
||||
for name, tensor in module_list:
|
||||
if special_dtypes is not None and name in special_dtypes:
|
||||
size = tensor.numel() * special_dtypes_size[name]
|
||||
elif dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||
# so use their original size here
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
|
||||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
@@ -1012,7 +1080,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works.
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -1042,7 +1110,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
@@ -1076,7 +1144,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
@@ -1104,7 +1172,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works.
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -1132,7 +1200,7 @@ class ModelTesterMixin:
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
@@ -1164,7 +1232,7 @@ class ModelTesterMixin:
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
variant = "fp16"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -1204,7 +1272,7 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
@@ -1233,7 +1301,7 @@ class ModelTesterMixin:
|
||||
config, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
|
||||
variant = "fp16"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
@@ -30,6 +30,8 @@ class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = MochiTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
# Overriding it because of the transformer size.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import SanaTransformer2DModel
|
||||
@@ -80,3 +81,27 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"SanaTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cuda",
|
||||
reason="Test currently fails.",
|
||||
strict=True,
|
||||
)
|
||||
def test_cpu_offload(self):
|
||||
return super().test_cpu_offload()
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cuda",
|
||||
reason="Test currently fails.",
|
||||
strict=True,
|
||||
)
|
||||
def test_disk_offload_with_safetensors(self):
|
||||
return super().test_disk_offload_with_safetensors()
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cuda",
|
||||
reason="Test currently fails.",
|
||||
strict=True,
|
||||
)
|
||||
def test_disk_offload_without_safetensors(self):
|
||||
return super().test_disk_offload_without_safetensors()
|
||||
|
||||
@@ -63,10 +63,19 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTXVideo(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=8,
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_block_out_channels=(8, 8, 8, 8),
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
decoder_layers_per_block=(1, 1, 1, 1, 1),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_inject_noise=(False, False, False, False, False),
|
||||
upsample_residual=(False, False, False, False),
|
||||
upsample_factor=(1, 1, 1, 1),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
|
||||
@@ -68,10 +68,19 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTXVideo(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=8,
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_block_out_channels=(8, 8, 8, 8),
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
decoder_layers_per_block=(1, 1, 1, 1, 1),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_inject_noise=(False, False, False, False, False),
|
||||
upsample_residual=(False, False, False, False),
|
||||
upsample_factor=(1, 1, 1, 1),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
|
||||
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
|
||||
|
||||
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -101,7 +101,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = Gemma2Config(
|
||||
head_dim=16,
|
||||
hidden_size=32,
|
||||
hidden_size=8,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=64,
|
||||
max_position_embeddings=8192,
|
||||
@@ -112,7 +112,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
vocab_size=8,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
text_encoder = Gemma2ForCausalLM(text_encoder_config)
|
||||
text_encoder = Gemma2Model(text_encoder_config)
|
||||
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils.testing_utils import (
|
||||
nightly,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torchao_version_greater,
|
||||
require_torchao_version_greater_or_equal,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -74,13 +74,13 @@ if is_torch_available():
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater("0.6.0")
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
@@ -125,25 +125,28 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater("0.6.0")
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_components(self, quantization_config: TorchAoConfig):
|
||||
model_id = "hf-internal-testing/tiny-flux-pipe"
|
||||
def get_dummy_components(
|
||||
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
|
||||
):
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
@@ -209,10 +212,10 @@ class TorchAoTest(unittest.TestCase):
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str):
|
||||
components = self.get_dummy_components(quantization_config, model_id)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device, dtype=torch.bfloat16)
|
||||
pipe.to(device=torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0]
|
||||
@@ -221,44 +224,45 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_quantization(self):
|
||||
# fmt: off
|
||||
QUANTIZATION_TYPES_TO_TEST = [
|
||||
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
|
||||
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
|
||||
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
]
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
# fmt: off
|
||||
QUANTIZATION_TYPES_TO_TEST = [
|
||||
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
|
||||
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
|
||||
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
]
|
||||
|
||||
if TorchAoConfig._is_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
# =====
|
||||
# The following lead to an internal torch error:
|
||||
# RuntimeError: mat2 shape (32x4 must be divisible by 16
|
||||
# Skip these for now; TODO(aryan): investigate later
|
||||
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
# Cutlass fails to initialize for below
|
||||
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
])
|
||||
# fmt: on
|
||||
if TorchAoConfig._is_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
# =====
|
||||
# The following lead to an internal torch error:
|
||||
# RuntimeError: mat2 shape (32x4 must be divisible by 16
|
||||
# Skip these for now; TODO(aryan): investigate later
|
||||
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
# Cutlass fails to initialize for below
|
||||
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quant_kwargs = {}
|
||||
if quantization_name in ["uint4wo", "uint7wo"]:
|
||||
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
|
||||
quant_kwargs.update({"group_size": 16})
|
||||
quantization_config = TorchAoConfig(
|
||||
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
|
||||
)
|
||||
self._test_quant_type(quantization_config, expected_slice)
|
||||
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quant_kwargs = {}
|
||||
if quantization_name in ["uint4wo", "uint7wo"]:
|
||||
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
|
||||
quant_kwargs.update({"group_size": 16})
|
||||
quantization_config = TorchAoConfig(
|
||||
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
|
||||
)
|
||||
self._test_quant_type(quantization_config, expected_slice, model_id)
|
||||
|
||||
def test_int4wo_quant_bfloat16_conversion(self):
|
||||
"""
|
||||
@@ -276,15 +280,16 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
self.assertEqual(weight.quant_min, 0)
|
||||
self.assertEqual(weight.quant_max, 15)
|
||||
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
|
||||
|
||||
def test_device_map(self):
|
||||
# Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did
|
||||
# it would have errored out. Now, we do. So, device_map basically never worked with or without
|
||||
# sharded checkpoints. This will need to be supported in the future (TODO(aryan))
|
||||
"""
|
||||
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
|
||||
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
|
||||
correctly set (in the `hf_device_map` attribute of the model).
|
||||
"""
|
||||
|
||||
custom_device_map_dict = {
|
||||
"time_text_embed": torch_device,
|
||||
"context_embedder": torch_device,
|
||||
@@ -296,51 +301,73 @@ class TorchAoTest(unittest.TestCase):
|
||||
}
|
||||
device_maps = ["auto", custom_device_map_dict]
|
||||
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
|
||||
# inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
# expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
|
||||
|
||||
for device_map in device_maps:
|
||||
device_map_to_compare = {"": 0} if device_map == "auto" else device_map
|
||||
# device_map_to_compare = {"": 0} if device_map == "auto" else device_map
|
||||
|
||||
# Test non-sharded model
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
# Test non-sharded model - should work
|
||||
with self.assertRaises(NotImplementedError):
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
_ = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
|
||||
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
|
||||
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
# output = quantized_model(**inputs)[0]
|
||||
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# Test sharded model
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-sharded",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
# Test sharded model - should not work
|
||||
with self.assertRaises(NotImplementedError):
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
_ = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-sharded",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
|
||||
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
|
||||
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
# output = quantized_model(**inputs)[0]
|
||||
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_modules_to_not_convert(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
|
||||
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
|
||||
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
|
||||
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
|
||||
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
|
||||
|
||||
quantized_layer = quantized_model_with_not_convert.proj_out
|
||||
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
|
||||
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -348,14 +375,10 @@ class TorchAoTest(unittest.TestCase):
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
|
||||
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
|
||||
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
|
||||
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
|
||||
size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
|
||||
size_quantized = get_model_size_in_bytes(quantized_model)
|
||||
|
||||
quantized_layer = quantized_model.proj_out
|
||||
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
|
||||
self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)
|
||||
self.assertTrue(size_quantized < size_quantized_with_not_convert)
|
||||
|
||||
def test_training(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
@@ -391,83 +414,82 @@ class TorchAoTest(unittest.TestCase):
|
||||
@nightly
|
||||
def test_torch_compile(self):
|
||||
r"""Test that verifies if torch.compile works with torchao quantization."""
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device, dtype=torch.bfloat16)
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
components = self.get_dummy_components(quantization_config, model_id=model_id)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
normal_output = pipe(**inputs)[0].flatten()[-32:]
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
normal_output = pipe(**inputs)[0].flatten()[-32:]
|
||||
|
||||
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
compile_output = pipe(**inputs)[0].flatten()[-32:]
|
||||
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
compile_output = pipe(**inputs)[0].flatten()[-32:]
|
||||
|
||||
# Note: Seems to require higher tolerance
|
||||
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
|
||||
|
||||
@staticmethod
|
||||
def _get_memory_footprint(module):
|
||||
quantized_param_memory = 0.0
|
||||
unquantized_param_memory = 0.0
|
||||
|
||||
for param in module.parameters():
|
||||
if param.__class__.__name__ == "AffineQuantizedTensor":
|
||||
data, scale, zero_point = param.layout_tensor.get_plain()
|
||||
quantized_param_memory += data.numel() + data.element_size()
|
||||
quantized_param_memory += scale.numel() + scale.element_size()
|
||||
quantized_param_memory += zero_point.numel() + zero_point.element_size()
|
||||
else:
|
||||
unquantized_param_memory += param.data.numel() * param.data.element_size()
|
||||
|
||||
total_memory = quantized_param_memory + unquantized_param_memory
|
||||
return total_memory, quantized_param_memory, unquantized_param_memory
|
||||
# Note: Seems to require higher tolerance
|
||||
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
|
||||
|
||||
def test_memory_footprint(self):
|
||||
r"""
|
||||
A simple test to check if the model conversion has been done correctly by checking on the
|
||||
memory footprint of the converted model and the class type of the linear layers of the converted models
|
||||
"""
|
||||
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
|
||||
transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
|
||||
transformer_bf16 = self.get_dummy_components(None)["transformer"]
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
|
||||
transformer_int4wo_gs32 = self.get_dummy_components(
|
||||
TorchAoConfig("int4wo", group_size=32), model_id=model_id
|
||||
)["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
|
||||
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
|
||||
|
||||
total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
|
||||
total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
|
||||
transformer_int4wo_gs32
|
||||
)
|
||||
total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
|
||||
total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)
|
||||
# Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64
|
||||
for block in transformer_int4wo.transformer_blocks:
|
||||
self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
|
||||
self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
|
||||
|
||||
self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
|
||||
# int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
|
||||
self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
|
||||
# int4 with default group size quantized very few linear layers compared to a smaller group size of 32
|
||||
self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
|
||||
# int8 quantizes more layers compare to int4 with default group size
|
||||
self.assertTrue(quantized_int8wo < quantized_int4wo)
|
||||
# Will quantize all the linear layers except x_embedder
|
||||
for name, module in transformer_int4wo_gs32.named_modules():
|
||||
if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
|
||||
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
|
||||
|
||||
# Will quantize all the linear layers
|
||||
for module in transformer_int8wo.modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
|
||||
|
||||
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
|
||||
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
|
||||
total_int8wo = get_model_size_in_bytes(transformer_int8wo)
|
||||
total_bf16 = get_model_size_in_bytes(transformer_bf16)
|
||||
|
||||
# TODO: refactor to align with other quantization tests
|
||||
# Latter has smaller group size, so more groups -> more scales and zero points
|
||||
self.assertTrue(total_int4wo < total_int4wo_gs32)
|
||||
# int8 quantizes more layers compare to int4 with default group size
|
||||
self.assertTrue(total_int8wo < total_int4wo)
|
||||
# int4wo does not quantize too many layers because of default group size, but for the layers it does
|
||||
# there is additional overhead of scales and zero points
|
||||
self.assertTrue(total_bf16 < total_int4wo)
|
||||
|
||||
def test_wrong_config(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.get_dummy_components(TorchAoConfig("int42"))
|
||||
|
||||
|
||||
# This class is not to be run as a test by itself. See the tests that follow this class
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater("0.6.0")
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||
quant_method, quant_method_kwargs = None, None
|
||||
device = "cuda"
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_model(self, device=None):
|
||||
quantization_config = TorchAoConfig(self.quant_method, **self.quant_method_kwargs)
|
||||
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
|
||||
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
self.model_name,
|
||||
subfolder="transformer",
|
||||
@@ -503,21 +525,23 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def test_original_model_expected_slice(self):
|
||||
quantized_model = self.get_dummy_model(torch_device)
|
||||
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
self.assertTrue(np.allclose(output_slice, self.expected_slice, atol=1e-3, rtol=1e-3))
|
||||
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
|
||||
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def check_serialization_expected_slice(self, expected_slice):
|
||||
quantized_model = self.get_dummy_model(self.device)
|
||||
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False
|
||||
)
|
||||
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
|
||||
).to(device=torch_device)
|
||||
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
output = loaded_quantized_model(**inputs)[0]
|
||||
@@ -530,42 +554,39 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_serialization_expected_slice(self):
|
||||
self.check_serialization_expected_slice(self.serialized_expected_slice)
|
||||
def test_int_a8w8_cuda(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = "cuda"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
def test_int_a16w8_cuda(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = "cuda"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
class TorchAoSerializationINTA8W8Test(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cuda"
|
||||
def test_int_a8w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
|
||||
class TorchAoSerializationINTA16W8Test(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cuda"
|
||||
|
||||
|
||||
class TorchAoSerializationINTA8W8CPUTest(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cpu"
|
||||
|
||||
|
||||
class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
serialized_expected_slice = expected_slice
|
||||
device = "cpu"
|
||||
def test_int_a16w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater("0.6.0")
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoTests(unittest.TestCase):
|
||||
@@ -574,18 +595,25 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_components(self, quantization_config: TorchAoConfig):
|
||||
# This is just for convenience, so that we can modify it at one place for custom environments and locally testing
|
||||
cache_dir = None
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
|
||||
)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
@@ -617,13 +645,15 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
|
||||
def _test_quant_type(self, quantization_config, expected_slice):
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
|
||||
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0].flatten()
|
||||
output_slice = np.concatenate((output[:16], output[-16:]))
|
||||
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_quantization(self):
|
||||
@@ -636,7 +666,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
if TorchAoConfig._is_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
|
||||
("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
|
||||
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
@@ -646,3 +676,125 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def test_serialization_int8wo(self):
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
weight = pipe.transformer.x_embedder.weight
|
||||
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0].flatten()[:128]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
pipe.remove_all_hooks()
|
||||
del pipe.transformer
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
|
||||
)
|
||||
pipe.transformer = transformer
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
weight = transformer.x_embedder.weight
|
||||
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
|
||||
loaded_output = pipe(**inputs)[0].flatten()[:128]
|
||||
# Seems to require higher tolerance depending on which machine it is being run.
|
||||
# A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of
|
||||
# 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04,
|
||||
# on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here.
|
||||
self.assertTrue(np.allclose(output, loaded_output, atol=0.06))
|
||||
|
||||
def test_memory_footprint_int4wo(self):
|
||||
# The original checkpoints are in bf16 and about 24 GB
|
||||
expected_memory_in_gb = 6.0
|
||||
quantization_config = TorchAoConfig("int4wo")
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
|
||||
self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
|
||||
|
||||
def test_memory_footprint_int8wo(self):
|
||||
# The original checkpoints are in bf16 and about 24 GB
|
||||
expected_memory_in_gb = 12.0
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
|
||||
self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_inputs(self, device: torch.device, seed: int = 0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator().manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "an astronaut riding a horse in space",
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
"num_inference_steps": 20,
|
||||
"output_type": "np",
|
||||
"generator": generator,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_transformer_int8wo(self):
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703])
|
||||
# fmt: on
|
||||
|
||||
# This is just for convenience, so that we can modify it at one place for custom environments and locally testing
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
use_safetensors=False,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Verify that all linear layer weights are quantized
|
||||
for name, module in pipe.transformer.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
|
||||
|
||||
# Verify outputs match expected slice
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0].flatten()
|
||||
output_slice = np.concatenate((output[:16], output[-16:]))
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user