Compare commits

25 Commits

Author SHA1 Message Date
DN6 5b413d949d update 2025-03-28 13:31:49 +01:00
Dhruv Nair e793adc465 update 2025-03-17 15:22:49 +01:00
Sayak Paul 100142586f [CI] pin transformers version for benchmarking. (#11067)
pin transformers version for benchmarking.
2025-03-16 10:27:35 +05:30
Yuxuan Zhang 82188cef04 CogView4 Control Block (#10809)
* cogview4 control training


---------

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
2025-03-15 07:15:56 -10:00
Sayak Paul cc19726f3d [Tests] add requires peft decorator. (#11037)
* add requires peft decorator.

* install peft conditionally.

* conditional deps.

Co-authored-by: DN6 <dhruv.nair@gmail.com>

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-03-15 12:56:41 +05:30
Dimitri Barbot be54a95b93 Fix deterministic issue when getting pipeline dtype and device (#10696)
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-03-15 07:50:58 +05:30
Juan Acevedo 6b9a3334db reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)
reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop.

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
2025-03-14 12:47:01 -10:00
Andreas Jörg 8ead643bb7 [examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)
Fix: dtype mismatch of prompt embeddings in sd3 controlnet training

Co-authored-by: Andreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-14 17:33:15 +05:30
Sayak Paul 124ac3e81f [LoRA] feat: support non-diffusers wan t2v loras. (#11059)
feat: support non-diffusers wan t2v loras.
2025-03-14 16:01:25 +05:30
Sayak Paul 2f0f281b0d [Tests] restrict memory tests for quanto for certain schemes. (#11052)
* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

* style

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-03-14 10:35:19 +05:30
ZhengKai91 ccc8321651 Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)
* get_1d_rotary_pos_embed support npu

* Update src/diffusers/models/embeddings.py

---------

Co-authored-by: Kai zheng <kaizheng@KaideMacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-03-13 09:58:03 -10:00
Yaniv Galron 5e48cd27d4 making ``formatted_images`` initialization compact (#10801)
compact writing

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-03-13 09:27:14 -10:00
hlky 5551506b29 Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (#10827)
* Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline


---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-03-13 09:24:21 -10:00
Sayak Paul 20e4b6a628 [LoRA] change to warning from info when notifying the users about a LoRA no-op (#11044)
* move to warning.

* test related changes.
2025-03-12 21:20:48 +05:30
hlky 4ea9f89b8e Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)
* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review
2025-03-12 12:05:52 +00:00
hlky 733b44ac82 [hybrid inference 🍯🐝] Add VAE encode (#11017)
* [hybrid inference 🍯🐝] Add VAE encode

* _toctree: add vae encode

* Add endpoints, tests

* vae_encode docs

* vae encode benchmarks

* api reference

* changelog

* Update docs/source/en/hybrid_inference/overview.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-12 11:23:41 +00:00
hlky 8b4f8ba764 Use output_size in repeat_interleave (#11030) 2025-03-12 07:30:21 +00:00
Dhruv Nair 5428046437 [Refactor] Clean up import utils boilerplate (#11026)
* update

* update

* update
2025-03-12 07:48:34 +05:30
39th president of the United States, probably e7ffeae0a1 Fix for multi-GPU WAN inference (#10997)
Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs

Co-authored-by: Jimmy <39@🇺🇸.com>
2025-03-11 07:42:12 -10:00
CyberVy d87ce2cefc Fix missing **kwargs in lora_pipeline.py (#11011)
* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-03-11 07:34:27 -10:00
wonderfan 36d0553af2 chore: fix help messages in advanced diffusion examples (#10923) 2025-03-11 07:33:55 -10:00
hlky 7e0db46f73 Fix SD3 IPAdapter feature extractor (#11027) 2025-03-11 16:29:27 +00:00
Sayak Paul e4b056fe65 [LoRA] support wan i2v loras from the world. (#11025)
* support wan i2v loras from the world.

* remove copied from.

* upates

* add lora.
2025-03-11 20:43:29 +05:30
Eliseu Silva 4e3ddd5afa fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings (#11012)
small fix on generating time_ids & embeddings
2025-03-11 04:20:18 -03:00
Dhruv Nair 9add071592 [Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6 (#11018)
* update

* update

* update

* update

* update

* update

* update

* update

* update
2025-03-11 10:52:01 +05:30
76 changed files with 4000 additions and 538 deletions
+1
@@ -38,6 +38,7 @@ jobs:
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install pandas peft
python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
- name: Environment
run: |
python utils/print_env.py
+7
@@ -414,12 +414,16 @@ jobs:
config:
- backend: "bitsandbytes"
test_location: "bnb"
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: []
- backend: "torchao"
test_location: "torchao"
additional_deps: []
- backend: "optimum_quanto"
test_location: "quanto"
additional_deps: []
runs-on:
group: aws-g6e-xlarge-plus
container:
@@ -437,6 +441,9 @@ jobs:
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install -U ${{ matrix.config.backend }}
if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
python -m uv pip install pytest-reportlog
- name: Environment
run: |
+2
@@ -81,6 +81,8 @@
title: Overview
- local: hybrid_inference/vae_decode
title: VAE Decode
- local: hybrid_inference/vae_encode
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
title: Hybrid Inference
+7 -7
@@ -58,10 +58,10 @@ Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fa
First, load the pipeline:
```python
from diffusers import LuminaText2ImgPipeline
from diffusers import LuminaPipeline
import torch
pipeline = LuminaText2ImgPipeline.from_pretrained(
pipeline = LuminaPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
```
@@ -86,11 +86,11 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on image quality depending on the model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaText2ImgPipeline`] for inference with bitsandbytes.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaText2ImgPipeline
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -109,7 +109,7 @@ transformer_8bit = Transformer2DModel.from_pretrained(
torch_dtype=torch.float16,
)
pipeline = LuminaText2ImgPipeline.from_pretrained(
pipeline = LuminaPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
@@ -122,9 +122,9 @@ image = pipeline(prompt).images[0]
image.save("lumina.png")
```
## LuminaText2ImgPipeline
## LuminaPipeline
[[autodoc]] LuminaText2ImgPipeline
[[autodoc]] LuminaPipeline
- all
- __call__
+6 -6
@@ -36,14 +36,14 @@ Single file loading for Lumina Image 2.0 is available for the `Lumina2Transforme
```python
import torch
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
transformer = Lumina2Transformer2DModel.from_single_file(
ckpt_path, torch_dtype=torch.bfloat16
)
pipe = Lumina2Text2ImgPipeline.from_pretrained(
pipe = Lumina2Pipeline.from_pretrained(
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
@@ -60,7 +60,7 @@ image.save("lumina-single-file.png")
GGUF Quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`
```python
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig
ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
transformer = Lumina2Transformer2DModel.from_single_file(
@@ -69,7 +69,7 @@ transformer = Lumina2Transformer2DModel.from_single_file(
torch_dtype=torch.bfloat16,
)
pipe = Lumina2Text2ImgPipeline.from_pretrained(
pipe = Lumina2Pipeline.from_pretrained(
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
@@ -80,8 +80,8 @@ image = pipe(
image.save("lumina-gguf.png")
```
## Lumina2Text2ImgPipeline
## Lumina2Pipeline
[[autodoc]] Lumina2Text2ImgPipeline
[[autodoc]] Lumina2Pipeline
- all
- __call__
+359 -12
@@ -14,22 +14,365 @@
# Wan
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
<!-- TODO(aryan): update abstract once paper is out -->
## Generating Videos with Wan 2.1
We will first need to install some additional dependencies.
```shell
pip install -U ftfy imageio-ffmpeg imageio
```
### Text to Video Generation
The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
for the larger `Wan2.1-I2V-14B-720P-Diffusers` or `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` if you have at least 35GB VRAM available.
```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video
# Available models: Wan-AI/Wan2.1-I2V-14B-720P-Diffusers or Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
You can improve the quality of the generated video by running the decoding step in full precision.
</Tip>
Recommendations for inference:
- VAE in `torch.float32` for better decoding quality.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
```python
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
### Using a custom scheduler
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
frames = pipe(prompt=prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```
### Image to Video Generation
The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
35GB of VRAM to run.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
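The height and width arithmetic in the example above can be checked in isolation. The following is a minimal sketch, not part of the pipeline API: `fit_resolution` is a hypothetical helper name, `math.sqrt` stands in for `np.sqrt`, and `mod_value=16` is an assumed value used only for illustration (in the example it is read from `pipe.vae_scale_factor_spatial` and `pipe.transformer.config.patch_size[1]`).

```python
import math


def fit_resolution(image_height, image_width, max_area, mod_value):
    """Return (height, width) close to max_area pixels that keeps the
    input aspect ratio, with both sides rounded down to a multiple of
    mod_value. Hypothetical helper mirroring the arithmetic above."""
    aspect_ratio = image_height / image_width
    height = round(math.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(math.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return height, width


# mod_value=16 is an assumption for illustration only.
height, width = fit_resolution(1024, 1024, max_area=480 * 832, mod_value=16)
print(height, width)
```

Because each side is rounded down to a multiple of `mod_value`, the resized image can be slightly smaller than `max_area`.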
## Memory Optimizations for Wan 2.1
Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
We'll use `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
### Group Offloading the Transformer and UMT5 Text Encoder
Find more information about group offloading [here](../optimization/memory.md)
#### Block Level Group Offloading
We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline: the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading breaks up the individual modules of a model and offloads/onloads them to and from your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which groups the modules in a model into blocks of size `num_blocks_per_group` and offloads/onloads each group together. Moving modules between CPU and GPU does add latency to the inference process. You can trade off latency against memory savings by increasing or decreasing `num_blocks_per_group`.
The following example will now only require 14GB of VRAM to run, but will take approximately 30 minutes to generate a video.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(text_encoder,
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4
)
transformer.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4,
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
# Since we've already offloaded the larger models, we can move the rest of the model components to GPU
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
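To make the effect of `num_blocks_per_group` concrete, here is a minimal sketch of how consecutive blocks could be partitioned into groups. This is an illustration only, not the library's implementation: `group_blocks` is a hypothetical helper, and the block count of 10 is made up.

```python
def group_blocks(num_blocks, num_blocks_per_group):
    """Split block indices 0..num_blocks-1 into consecutive groups of at
    most num_blocks_per_group entries. Hypothetical helper for illustration."""
    if num_blocks_per_group < 1:
        raise ValueError("num_blocks_per_group must be at least 1")
    return [
        list(range(start, min(start + num_blocks_per_group, num_blocks)))
        for start in range(0, num_blocks, num_blocks_per_group)
    ]


# 10 blocks is an assumed count, chosen so the last group is partial.
for group in group_blocks(num_blocks=10, num_blocks_per_group=4):
    print(group)
```

Larger groups mean fewer CPU/GPU transfers but more weights resident on the GPU at once; smaller groups do the opposite.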
#### Block Level Group Offloading with CUDA Streams
We can speed up group offloading inference by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by PyTorch under the hood and can result in a significant spike in CPU RAM usage. Consider this option if your CPU RAM is at least 2x the size of the model you are group offloading.
In the following example, we will use CUDA streams when group offloading the `WanTransformer3DModel`. When tested on an A100, this example requires 14GB of VRAM and 52GB of CPU RAM, but generates a video in approximately 9 minutes.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(text_encoder,
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4
)
transformer.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
# Since we've already offloaded the larger models, we can move the rest of the model components to GPU
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### Applying Layerwise Casting to the Transformer
Find more information about layerwise casting [here](../optimization/memory.md)
In this example, we will combine model offloading with layerwise casting. Layerwise casting stores each layer's weights in `torch.float8_e4m3fn`, temporarily upcasts them to `torch.bfloat16` during the layer's forward pass, then reverts them to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
This example will require 20GB of VRAM.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
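The "approximately 50%" figure follows from the storage size of each dtype: `torch.bfloat16` uses 2 bytes per parameter and `torch.float8_e4m3fn` uses 1. A rough back-of-the-envelope sketch, assuming about 14 billion transformer parameters (inferred from the "14B" in the checkpoint name) and counting weights only:

```python
# Bytes needed to store one parameter in each dtype.
BYTES_PER_PARAM = {"bfloat16": 2, "float8_e4m3fn": 1}


def weight_memory_gb(num_params, dtype):
    """Approximate weight storage in GB (10**9 bytes). Weights only."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Assumption: ~14e9 parameters, taken from the "14B" checkpoint name.
num_params = 14e9
bf16_gb = weight_memory_gb(num_params, "bfloat16")
fp8_gb = weight_memory_gb(num_params, "float8_e4m3fn")
saving = 1 - fp8_gb / bf16_gb
print(f"bf16: {bf16_gb:.1f} GB, fp8: {fp8_gb:.1f} GB, saving: {saving:.0%}")
```

This estimate ignores activations, the temporary upcast during each layer's forward pass, and the other pipeline components, so measured VRAM usage will differ from these numbers.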
### Using a Custom Scheduler
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
@@ -45,11 +388,10 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
```
### Using single file loading with Wan
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
method.
## Using Single File Loading with Wan 2.1
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
method.
```python
import torch
@@ -61,6 +403,11 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
## Recommendations for Inference:
- Keep `AutoencoderKLWan` in `torch.float32` for better decoding quality.
- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
- For smaller resolution videos, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
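
The `num_frames` constraint above can be checked before calling the pipeline. A minimal sketch, assuming nothing beyond the stated rule `(num_frames - 1) % 4 == 0`; both helper names are hypothetical and not part of the library:

```python
def is_valid_num_frames(num_frames):
    """True if num_frames has the form 4 * k + 1 for some k >= 0."""
    return num_frames >= 1 and (num_frames - 1) % 4 == 0


def floor_to_valid_num_frames(num_frames):
    """Round down to the largest valid frame count <= num_frames."""
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    return 4 * ((num_frames - 1) // 4) + 1


print(is_valid_num_frames(33), floor_to_valid_num_frames(50))
```

Rounding down rather than to the nearest value keeps the result within the requested frame budget.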
## WanPipeline
[[autodoc]] WanPipeline
@@ -3,3 +3,7 @@
## Remote Decode
[[autodoc]] utils.remote_utils.remote_decode
## Remote Encode
[[autodoc]] utils.remote_utils.remote_encode
+8 -2
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
## Available Models
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
---
@@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
## Changelog
- March 10 2025: Added VAE encode
- March 2 2025: Initial release with VAE decoding
## Contents
The documentation is organized into two sections:
The documentation is organized into three sections:
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.
@@ -0,0 +1,183 @@
# Getting Started: VAE Encode with Hybrid Inference
VAE encode is used for training, image-to-image and image-to-video, turning images or videos into latent representations.
## Memory
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
For most of these GPUs, the memory usage percentage means that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled encoding must be used, which increases encoding time and impacts quality.
<details><summary>SD v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
</details>
<details><summary>SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
</details>
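To make the trade-off concrete, the sketch below takes the 2048x2048 rows for the RTX 4090 from the two tables above and computes how much time tiled encoding costs and how much memory it frees. The `Row` class and the `tiling_tradeoff` helper are illustrative names, not part of `diffusers`.

```python
from dataclasses import dataclass


@dataclass
class Row:
    """One benchmark row copied from the tables above."""

    model: str
    time_s: float
    memory_pct: float
    tiled_time_s: float
    tiled_memory_pct: float


def tiling_tradeoff(row: Row) -> tuple[float, float]:
    """Return (tiled / untiled time ratio, memory saved in percentage points)."""
    return row.tiled_time_s / row.time_s, row.memory_pct - row.tiled_memory_pct


# RTX 4090 at 2048x2048, values taken from the tables above.
rows = [
    Row("SD v1.5", 0.402, 47.1852, 0.496, 3.51901),
    Row("SDXL", 0.873, 66.3452, 0.863, 15.5649),
]

for row in rows:
    ratio, saved = tiling_tradeoff(row)
    print(f"{row.model}: tiled takes {ratio:.2f}x the time, frees {saved:.1f} points of VRAM")
```

The same helper can be applied to any other row to decide whether tiling is worth it on a given GPU.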
## Available VAEs
| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
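
When switching between models it can help to keep each endpoint together with its scaling values. The sketch below is one possible way to do that; `RemoteVAE`, `ENCODE_VAES` and `encode_kwargs` are hypothetical names. The endpoints come from the table above, and the scaling and shift factors are the ones used in the examples later in this guide. Stable Diffusion XL is left out because this guide does not state its scaling factor.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class RemoteVAE:
    """Endpoint and latent scaling values for one remote VAE (illustrative)."""

    endpoint: str
    scaling_factor: float
    shift_factor: Optional[float] = None


# Endpoints are copied from the table above; factors from this guide's examples.
ENCODE_VAES = {
    "sd-v1": RemoteVAE(
        "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
        scaling_factor=0.18215,
    ),
    "flux": RemoteVAE(
        "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
        scaling_factor=0.3611,
        shift_factor=0.1159,
    ),
}


def encode_kwargs(name: str) -> dict:
    """Build the keyword arguments to pass alongside `image=` to `remote_encode`."""
    vae = ENCODE_VAES[name]
    kwargs = {"endpoint": vae.endpoint, "scaling_factor": vae.scaling_factor}
    if vae.shift_factor is not None:
        kwargs["shift_factor"] = vae.shift_factor
    return kwargs


print(encode_kwargs("flux"))
```

A call would then look like `remote_encode(image=image, **encode_kwargs("flux"))`.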
## Code
> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
A helper method simplifies interacting with Hybrid Inference.
```python
from diffusers.utils.remote_utils import remote_encode
```
### Basic example
Let's encode an image, then decode it to demonstrate.
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
</figure>
<details><summary>Code</summary>
```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")

# Encode the image into a latent with the remote Flux VAE.
latent = remote_encode(
    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
    image=image,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)

# Decode the latent back into an image to check the round trip.
decoded = remote_decode(
    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
</figure>
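One way to put a number on the round trip is to compare the original and decoded images with PSNR. The sketch below is not part of `diffusers`: `psnr` is a hypothetical helper, and it assumes both inputs are PIL images of the same size (the decode call is assumed to return a PIL image, as the generation example in this guide saves its result with `.save`). Synthetic images stand in for `image` and `decoded` so the snippet runs offline.

```python
import numpy as np
from PIL import Image


def psnr(original: Image.Image, reconstructed: Image.Image) -> float:
    """Peak signal-to-noise ratio in dB between two same-sized RGB images."""
    a = np.asarray(original.convert("RGB"), dtype=np.float64)
    b = np.asarray(reconstructed.convert("RGB"), dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"size mismatch: {a.shape} vs {b.shape}")
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10 * np.log10(255.0**2 / mse))


# Stand-in data: random pixels plus small noise. In practice pass `image` and `decoded`.
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
noise = rng.integers(-2, 3, size=pixels.shape)
noisy = np.clip(pixels.astype(np.int16) + noise, 0, 255).astype(np.uint8)

print(f"{psnr(Image.fromarray(pixels), Image.fromarray(noisy)):.1f} dB")
```

Higher values mean the reconstruction is closer to the input; identical images give infinity.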
### Generation
Now let's look at a generation example: we'll encode the image, generate, and then remotely decode too!
<details><summary>Code</summary>
```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
vae=None,
).to("cuda")
init_image = load_image(
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
init_latent = remote_encode(
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
image=init_image,
scaling_factor=0.18215,
)
prompt = "A fantasy landscape, trending on artstation"
latent = pipe(
prompt=prompt,
image=init_latent,
strength=0.75,
output_type="latent",
).images
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.18215,
)
image.save("fantasy_landscape.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
</figure>
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
+1 -1
@@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
```python
import torch
@@ -79,13 +79,13 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
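
Writing long `--lora_layers` strings by hand is error-prone, so they can also be generated. The helper below is a hypothetical convenience function, not part of the training script; the block indices 7 and 12 are arbitrary examples.

```python
def build_lora_layers(layers, block_prefix=None, block_indices=()):
    """Join layer names into a value for --lora_layers.

    With a block prefix, every layer is repeated once per block index,
    e.g. "single_transformer_blocks.7.attn.to_k".
    """
    if block_prefix is None:
        return ",".join(layers)
    return ",".join(
        f"{block_prefix}.{index}.{layer}" for index in block_indices for layer in layers
    )


attention = ["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"]

# Attention-only layers in every block.
print(build_lora_layers(attention))

# Only `attn.to_k` in single transformer blocks 7 and 12.
print(build_lora_layers(["attn.to_k"], "single_transformer_blocks", [7, 12]))
```

The printed strings can be passed directly as `--lora_layers="..."`.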
> [!NOTE]
@@ -378,7 +378,7 @@ def parse_args(input_args=None):
default=None,
help="the concept to use to initialize the new inserted tokens when training with "
"--train_text_encoder_ti = True. By default, new tokens (<si><si+1>) are initialized with random value. "
"Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. "
"Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. "
"--num_new_tokens_per_abstraction is ignored when initializer_concept is provided",
)
parser.add_argument(
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
type=str,
default=None,
help=(
"The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. "
"The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. "
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
),
)
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
action="store_true",
default=False,
help=(
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
@@ -773,7 +773,7 @@ def parse_args(input_args=None):
action="store_true",
default=False,
help=(
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
@@ -1875,7 +1875,7 @@ def main(args):
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
# if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
add_special_tokens = True if args.train_text_encoder_ti else False
if not train_dataset.custom_instance_prompts:
+201
@@ -0,0 +1,201 @@
# Training CogView4 Control
This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We also provide a script for full fine-tuning; refer to [this section](#full-fine-tuning). To learn more about the CogView4 Control family, refer to the following resources:
To incorporate additional condition latents, we expand the input features of CogView4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of the rest of the network. Inference is performed using the `CogView4ControlPipeline`.
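As a rough illustration of this kind of expansion, the sketch below widens a linear projection from 64 to 128 input features. It is not the training script's code: `expand_input_features` is a hypothetical helper, the hidden size of 32 is a toy value, and zero-initializing the new columns is an assumption (a common choice, since it makes the expanded layer behave like the original one at the start of training).

```python
import torch
from torch import nn


def expand_input_features(proj: nn.Linear, extra_features: int) -> nn.Linear:
    """Return a copy of `proj` that accepts `extra_features` more inputs.

    The original weights are kept and the new input columns start at zero,
    so the expanded layer initially ignores the control channels.
    """
    expanded = nn.Linear(
        proj.in_features + extra_features,
        proj.out_features,
        bias=proj.bias is not None,
        device=proj.weight.device,
        dtype=proj.weight.dtype,
    )
    with torch.no_grad():
        expanded.weight.zero_()
        expanded.weight[:, : proj.in_features].copy_(proj.weight)
        if proj.bias is not None:
            expanded.bias.copy_(proj.bias)
    return expanded


# Toy sizes: 64 input features as in the text, and a made-up hidden size of 32.
proj = nn.Linear(64, 32)
expanded = expand_input_features(proj, extra_features=64)

latents = torch.randn(2, 10, 64)  # latents to be denoised
control = torch.randn(2, 10, 64)  # control latents
out = expanded(torch.cat([latents, control], dim=-1))

print(out.shape, torch.allclose(out, proj(latents), atol=1e-5))
```

During training the zeroed columns receive gradients, so the model gradually learns to use the control latents.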
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:
```bash
huggingface-cli login
```
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
```bash
accelerate launch train_control_lora_cogview4.py \
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
--dataset_name="raulc0399/open_pose_controlnet" \
--output_dir="pose-control-lora" \
--mixed_precision="bf16" \
--train_batch_size=1 \
--rank=64 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=5000 \
--validation_image="openpose.png" \
--validation_prompt="A couple, 4k photo, highly detailed" \
--offload \
--seed="0" \
--push_to_hub
```
`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from `main`.
The training script exposes additional CLI args that might be useful to experiment with:
* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.
* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading.
* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached.
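
The batch-related flags in the launch command above combine as follows. This is a small arithmetic sketch, assuming a single process and assuming that `--max_train_steps` counts optimizer updates rather than individual forward passes; `effective_batch_size` is an illustrative helper.

```python
def effective_batch_size(
    train_batch_size: int, gradient_accumulation_steps: int, num_processes: int = 1
) -> int:
    """Number of samples that contribute to one optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values from the LoRA launch command: batch size 1, 4 accumulation steps, 5000 steps.
batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
samples_seen = batch * 5000

print(batch, samples_seen)  # 4 20000
```

Raising `--gradient_accumulation_steps` increases the effective batch size without increasing peak memory, at the cost of more forward passes per update.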
### Training with DeepSpeed
It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the ZeRO-2 system optimization. To use it, save the following config to a YAML file (feel free to modify as needed):
```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
And then while launching training, pass the config file:
```bash
accelerate launch --config_file=CONFIG_FILE.yaml ...
```
### Inference
The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
```bash
pip install controlnet_aux
```
And then we are ready:
```py
from controlnet_aux import OpenposeDetector
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import torch
pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("...") # change this.
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
# prepare pose condition.
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
image = load_image(url)
image = open_pose(image, detect_resolution=512, image_resolution=1024)
image = np.array(image)[:, :, ::-1]
image = Image.fromarray(np.uint8(image))
prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
control_image=image,
num_inference_steps=50,
joint_attention_kwargs={"scale": 0.9},
guidance_scale=25.,
).images[0]
gen_images.save("output.png")
```
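The `[:, :, ::-1]` slice in the snippet above reverses the last axis of the array, which swaps the channel order of the pose image (RGB becomes BGR). A tiny standalone example of that one step:

```python
import numpy as np

# A 1x2 RGB image: one pure-red pixel and one pure-blue pixel.
rgb = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)

# Reversing the channel axis turns (R, G, B) into (B, G, R).
flipped = rgb[:, :, ::-1]

print(flipped.tolist())  # [[[0, 0, 255], [255, 0, 0]]]
```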
## Full fine-tuning
We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command:
```bash
accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
--dataset_name="raulc0399/open_pose_controlnet" \
--output_dir="pose-control" \
--mixed_precision="bf16" \
--train_batch_size=2 \
--dataloader_num_workers=4 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--proportion_empty_prompts=0.2 \
--learning_rate=5e-5 \
--adam_weight_decay=1e-4 \
--report_to="wandb" \
--lr_scheduler="cosine" \
--lr_warmup_steps=1000 \
--checkpointing_steps=1000 \
--max_train_steps=10000 \
--validation_steps=200 \
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
--offload \
--seed="0" \
--push_to_hub
```
Change the `validation_image` and `validation_prompt` as needed.
For inference, this time, we will run:
```py
from controlnet_aux import OpenposeDetector
from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import torch
transformer = CogView4Transformer2DModel.from_pretrained("...") # change this.
pipe = CogView4ControlPipeline.from_pretrained(
"THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
# prepare pose condition.
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
image = load_image(url)
image = open_pose(image, detect_resolution=512, image_resolution=1024)
image = np.array(image)[:, :, ::-1]
image = Image.fromarray(np.uint8(image))
prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
control_image=image,
num_inference_steps=50,
guidance_scale=25.,
).images[0]
gen_images.save("output.png")
```
## Things to note
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides similar functionality.
@@ -0,0 +1,6 @@
transformers==4.47.0
wandb
torch
torchvision
accelerate==1.2.0
peft>=0.14.0
File diff suppressed because it is too large
+22 -22
@@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1070,32 +1070,32 @@ class StableDiffusionXLTilingPipeline(
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left[row][col],
target_size,
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left[row][col],
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left[row][col],
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left[row][col],
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
else:
negative_add_time_ids = add_time_ids
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
embeddings_and_added_time.append(addition_embed_type_row)
+1 -3
@@ -152,9 +152,7 @@ def log_validation(
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
+1 -3
@@ -166,9 +166,7 @@ def log_validation(
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
+2 -2
@@ -1283,8 +1283,8 @@ def main(args):
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# Get the text embedding for conditioning
prompt_embeds = batch["prompt_embeds"]
pooled_prompt_embeds = batch["pooled_prompt_embeds"]
prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+1 -3
@@ -157,9 +157,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
@@ -381,9 +381,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
@@ -164,9 +164,7 @@ def log_validation(
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
@@ -50,51 +50,116 @@ python flux_inference.py
The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, compilation takes longer while the cache stores the compiled programs. On subsequent runs, compilation is much faster, and the passes that follow it are the fastest.
On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel):
On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
```bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|███████████████████████████████| 2/2 [00:00<00:00, 7.01it/s]
Loading pipeline components...: 40%|██████████ | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████████████████████| 5/5 [00:00<00:00, 6.72it/s]
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 4.29it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.26it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.27it/s]
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.25it/s]
2025-01-10 00:51:34 [info ] starting compilation run...
2025-01-10 00:51:35 [info ] starting compilation run...
2025-01-10 00:51:37 [info ] starting compilation run...
2025-01-10 00:51:37 [info ] starting compilation run...
2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec.
2025-01-10 00:52:53 [info ] starting inference run...
2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec.
2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec.
2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec.
2025-01-10 00:52:57 [info ] starting inference run...
2025-01-10 00:52:57 [info ] starting inference run...
2025-01-10 00:52:58 [info ] starting inference run...
2025-01-10 00:53:22 [info ] inference time: 25.112665320000815
2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655
2025-01-10 00:53:38 [info ] inference time: 7.693858365000779
2025-01-10 00:53:46 [info ] inference time: 7.690621814001133
2025-01-10 00:53:53 [info ] inference time: 7.679490454000188
2025-01-10 00:54:01 [info ] inference time: 7.68949568500102
2025-01-10 00:54:09 [info ] inference time: 7.686633744000574
2025-01-10 00:54:16 [info ] inference time: 7.696786873999372
2025-01-10 00:54:24 [info ] inference time: 7.691988694999964
2025-01-10 00:54:32 [info ] inference time: 7.700649563999832
2025-01-10 00:54:39 [info ] inference time: 7.684993574001055
2025-01-10 00:54:47 [info ] inference time: 7.68343457499941
2025-01-10 00:54:55 [info ] inference time: 7.667921153999487
2025-01-10 00:55:02 [info ] inference time: 7.683585194001353
2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec.
2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec.
2025-01-10 00:55:10 [info ] inference time: 7.673799695001435
2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec.
2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt
2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 7.06it/s]
Loading pipeline components...: 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.28it/s]
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.66it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 4.48it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.32it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.69it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.74it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.10it/s]
2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.55it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.46it/s]
2025-03-14 21:18:34 [info ] starting compilation run...
2025-03-14 21:18:37 [info ] starting compilation run...
2025-03-14 21:18:38 [info ] starting compilation run...
2025-03-14 21:18:39 [info ] starting compilation run...
2025-03-14 21:18:41 [info ] starting compilation run...
2025-03-14 21:18:41 [info ] starting compilation run...
2025-03-14 21:18:42 [info ] starting compilation run...
2025-03-14 21:18:43 [info ] starting compilation run...
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]
2025-03-14 21:36:38 [info ] compilation took 1079.3314765350078 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
2025-03-14 21:36:38 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
2025-03-14 21:36:38 [info ] compilation took 1081.89390801001 sec.
2025-03-14 21:36:39 [info ] starting inference run...
2025-03-14 21:36:39 [info ] compilation took 1077.1543154849933 sec.
2025-03-14 21:36:39 [info ] compilation took 1075.7239800530078 sec.
2025-03-14 21:36:39 [info ] starting inference run...
2025-03-14 21:36:40 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]
2025-03-14 21:36:50 [info ] compilation took 1088.1632604240003 sec.
2025-03-14 21:36:50 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]
2025-03-14 21:36:55 [info ] compilation took 1096.8027802760043 sec.
2025-03-14 21:36:56 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]
2025-03-14 21:37:08 [info ] compilation took 1113.8591305939917 sec.
2025-03-14 21:37:08 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]
2025-03-14 21:37:22 [info ] compilation took 1120.5590810020076 sec.
2025-03-14 21:37:22 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.98it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.08it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00, 3.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.09it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.41it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.50it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 4.80it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.67it/s]
29%|█████████████████████████████████████████████████████████████████████████████▍ | 8/28 [00:01<00:03, 6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
images = (images * 255).round().astype("uint8")
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.93it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.98it/s]
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.03it/s]2025-03-14 21:38:32 [info ] inference time: 5.962021178987925
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.2685392687970305 sec.
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.402720856998348 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.01it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.06it/s]
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.01it/s]2025-03-14 21:38:38 [info ] inference time: 5.950578948002658
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.87it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.86it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.05it/s]
2025-03-14 21:38:43 [info ] avg. inference over 5 iterations took 6.763298449796276 sec.
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.04it/s]2025-03-14 21:38:44 [info ] inference time: 5.949129879008979
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.92it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████ | 11/28 [00:01<00:02, 5.98it/s]2025-03-14 21:38:46 [info ] avg. inference over 5 iterations took 7.221068455604836 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.08it/s]
93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 26/28 [00:04<00:00, 5.92it/s]2025-03-14 21:38:50 [info ] inference time: 5.954778069004533
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.90it/s]
11%|█████████████████████████████ | 3/28 [00:00<00:04, 6.03it/s]2025-03-14 21:38:50 [info ] avg. inference over 5 iterations took 6.05970350120042 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
32%|███████████████████████████████████████████████████████████████████████████████████████ | 9/28 [00:01<00:03, 5.99it/s]2025-03-14 21:38:51 [info ] avg. inference over 5 iterations took 6.018543455796316 sec.
54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 15/28 [00:02<00:02, 6.00it/s]2025-03-14 21:38:52 [info ] avg. inference over 5 iterations took 5.9609976705978625 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.97it/s]
2025-03-14 21:38:56 [info ] inference time: 5.944058528999449
2025-03-14 21:38:56 [info ] avg. inference over 5 iterations took 5.952113320800708 sec.
2025-03-14 21:38:56 [info ] saved metric information as /tmp/metrics_report.txt
```
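The per-process figures in the log above can be summarized with a few lines of Python. This is a minimal sketch, not part of the benchmark script: the helper name `summarize_log` is hypothetical, and the regular expressions assume the exact message formats shown above (`compilation took … sec.` and `avg. inference over … iterations took … sec.`). Because progress bars and log messages can share a line, the patterns are searched across the whole text rather than matched line by line.

```python
import re
from statistics import mean

# Patterns follow the message formats visible in the log excerpt.
COMPILE_RE = re.compile(r"compilation took ([0-9.]+) sec\.")
AVG_RE = re.compile(r"avg\. inference over (\d+) iterations took ([0-9.]+) sec\.")


def summarize_log(text):
    """Return the mean compilation time and mean per-process average inference time."""
    compile_times = [float(m.group(1)) for m in COMPILE_RE.finditer(text)]
    avg_times = [float(m.group(2)) for m in AVG_RE.finditer(text)]
    return {
        "processes": len(avg_times),
        "mean_compile_sec": mean(compile_times) if compile_times else None,
        "mean_avg_inference_sec": mean(avg_times) if avg_times else None,
    }


# Four lines copied from the log above, used here as sample input.
sample = """\
2025-03-14 21:36:38 [info ] compilation took 1079.3314765350078 sec.
2025-03-14 21:36:38 [info ] compilation took 1081.89390801001 sec.
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.2685392687970305 sec.
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.402720856998348 sec.
"""
summary = summarize_log(sample)
print(summary)
```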
@@ -9,6 +9,7 @@ import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import FlashAttention
from diffusers import FluxPipeline
@@ -36,6 +37,19 @@ def _main(index, args, text_pipe, ckpt_id):
ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
).to(device0)
flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
FlashAttention.DEFAULT_BLOCK_SIZES = {
"block_q": 1536,
"block_k_major": 1536,
"block_k": 1536,
"block_b": 1536,
"block_q_major_dkv": 1536,
"block_k_major_dkv": 1536,
"block_q_dkv": 1536,
"block_k_dkv": 1536,
"block_q_dq": 1536,
"block_k_dq": 1536,
"block_k_major_dq": 1536,
}
prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
width = args.width
@@ -69,14 +83,14 @@ def _main(index, args, text_pipe, ckpt_id):
xm.set_rng_state(seed=unique_seed, device=device0)
times = []
logger.info("starting inference run...")
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)
prompt_embeds = prompt_embeds.to(device0)
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
for _ in range(args.itters):
ts = perf_counter()
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)
prompt_embeds = prompt_embeds.to(device0)
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
if args.profile:
xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
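The hunk above moves prompt encoding out of the timed loop, so each iteration measures only the denoising call. A minimal sketch of that pattern, assuming hypothetical stand-ins `encode_prompt` and `run_inference` in place of the real text pipeline and Flux pipeline:

```python
from time import perf_counter

calls = {"encode": 0}


def encode_prompt(prompt):
    """Hypothetical stand-in for the text encoder; it only counts invocations."""
    calls["encode"] += 1
    return [ord(ch) for ch in prompt]


def run_inference(embeds):
    """Hypothetical stand-in for the denoising call."""
    return sum(embeds)


def benchmark(prompt, iterations):
    # The embeddings do not change between iterations, so compute them once.
    embeds = encode_prompt(prompt)
    times = []
    for _ in range(iterations):
        start = perf_counter()
        run_inference(embeds)
        times.append(perf_counter() - start)
    return times


times = benchmark("trillium", 5)
print(len(times), calls["encode"])
```

With the encoder call hoisted, it runs once regardless of the iteration count, and its cost no longer appears in the per-iteration timings.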
@@ -92,7 +106,7 @@ def _main(index, args, text_pipe, ckpt_id):
if index == 0:
logger.info(f"inference time: {inference_time}")
times.append(inference_time)
logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
image.save(f"/tmp/inference_out-{index}.png")
if index == 0:
metrics_report = met.metrics_report()
@@ -141,9 +141,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
+13 -2
@@ -53,8 +53,18 @@ args = parser.parse_args()
# this is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into scale, shift, while CogView4 splits it into shift, scale
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
"""
Swap the scale and shift components in the weight tensor.
Args:
weight (torch.Tensor): The original weight tensor.
dim (int): The dimension along which to split.
Returns:
torch.Tensor: The modified weight tensor with scale and shift swapped.
"""
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
return new_weight
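The effect of `swap_scale_shift` can be illustrated without tensors. The sketch below uses a plain list and a hypothetical helper `swap_halves`; it assumes the split dimension has even length, so that chunking into two pieces yields equal halves.

```python
def swap_halves(rows):
    """Swap the first and second half of a list, mirroring chunk(2) followed by cat."""
    if len(rows) % 2 != 0:
        raise ValueError("expected an even number of rows")
    half = len(rows) // 2
    first, second = rows[:half], rows[half:]
    return second + first


# Labels stand in for the rows of a projection weight stored as (shift, scale).
weight = ["shift_0", "shift_1", "scale_0", "scale_1"]
swapped = swap_halves(weight)
print(swapped)

# Swapping twice restores the original order.
assert swap_halves(swapped) == weight
```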
@@ -200,6 +210,7 @@ def main(args):
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"shift_factor": 0.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
@@ -25,9 +25,15 @@ import argparse
import torch
from tqdm import tqdm
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
from transformers import GlmModel, PreTrainedTokenizerFast
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers import (
AutoencoderKL,
CogView4ControlPipeline,
CogView4Pipeline,
CogView4Transformer2DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
@@ -112,6 +118,12 @@ parser.add_argument(
default=128,
help="Maximum size for positional embeddings.",
)
parser.add_argument(
"--control",
action="store_true",
default=False,
help="Whether to use control model.",
)
args = parser.parse_args()
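The `--control` flag drives two shape changes elsewhere in this script's diff: `in_channels` becomes 32 instead of 16, and the patch-embedding weight is reshaped to width 128 instead of 64. A minimal sketch of how the two numbers relate, assuming the width equals `in_channels * patch_size**2` with `patch_size=2` (an inference from the values in the diff, not something the script states); `patch_embed_shapes` is a hypothetical helper:

```python
import argparse

PATCH_SIZE = 2  # matches the `patch_size=2` argument passed to the transformer


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--control", action="store_true", default=False, help="Whether to use control model.")
    return parser


def patch_embed_shapes(control):
    """Return (in_channels, patch_dim) for the plain and control variants."""
    in_channels = 32 if control else 16
    patch_dim = in_channels * PATCH_SIZE**2  # assumed relation, see lead-in
    return in_channels, patch_dim


for argv in ([], ["--control"]):
    args = build_parser().parse_args(argv)
    print(argv, patch_embed_shapes(args.control))
```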
@@ -150,13 +162,15 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
Returns:
dict: The converted state dictionary compatible with Diffusers.
"""
ckpt = torch.load(ckpt_path, map_location="cpu")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
mega = ckpt["model"]
new_state_dict = {}
# Patch Embedding
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
hidden_size, 128 if args.control else 64
)
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
@@ -189,14 +203,8 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
block_prefix = f"transformer_blocks.{i}."
# AdaLayerNorm
new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
mega[f"decoder.layers.{i}.adaln.weight"], dim=0
)
new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
mega[f"decoder.layers.{i}.adaln.bias"], dim=0
)
# QKV
new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
@@ -221,7 +229,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
# Attention Output
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
f"decoder.layers.{i}.self_attention.linear_proj.weight"
].T
]
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
f"decoder.layers.{i}.self_attention.linear_proj.bias"
]
@@ -252,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
Returns:
dict: The converted VAE state dictionary compatible with Diffusers.
"""
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
@@ -286,7 +294,7 @@ def main(args):
)
transformer = CogView4Transformer2DModel(
patch_size=2,
in_channels=16,
in_channels=32 if args.control else 16,
num_layers=args.num_layers,
attention_head_dim=args.attention_head_dim,
num_attention_heads=args.num_heads,
@@ -317,6 +325,7 @@ def main(args):
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"shift_factor": 0.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
@@ -331,7 +340,7 @@ def main(args):
# Load the text encoder and tokenizer
text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmForCausalLM.from_pretrained(
text_encoder = GlmModel.from_pretrained(
text_encoder_id,
cache_dir=args.text_encoder_cache_dir,
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
@@ -345,13 +354,22 @@ def main(args):
)
# Create the pipeline
pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
if args.control:
pipe = CogView4ControlPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
else:
pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
# Save the converted pipeline
pipe.save_pretrained(
+2 -2
@@ -5,7 +5,7 @@ import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
def main(args):
@@ -115,7 +115,7 @@ def main(args):
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
pipeline = LuminaText2ImgPipeline(
pipeline = LuminaPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
)
pipeline.save_pretrained(args.dump_path)
+14 -14
@@ -2,20 +2,14 @@ __version__ = "0.33.0.dev0"
from typing import TYPE_CHECKING
from diffusers.quantizers import quantization_config
from diffusers.utils import dummy_gguf_objects
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_optimum_quanto_version,
is_torchao_available,
)
from .utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
is_accelerate_available,
is_bitsandbytes_available,
is_flax_available,
is_gguf_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
@@ -24,6 +18,7 @@ from .utils import (
is_scipy_available,
is_sentencepiece_available,
is_torch_available,
is_torchao_available,
is_torchsde_available,
is_transformers_available,
)
@@ -65,7 +60,7 @@ _import_structure = {
}
try:
if not is_bitsandbytes_available():
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_bitsandbytes_objects
@@ -77,7 +72,7 @@ else:
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
try:
if not is_gguf_available():
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_gguf_objects
@@ -89,7 +84,7 @@ else:
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
try:
if not is_torchao_available():
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torchao_objects
@@ -101,7 +96,7 @@ else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
try:
if not is_optimum_quanto_available():
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_optimum_quanto_objects
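As written in the hunks above, each quantization guard chains three negated checks with `and`, so `OptionalDependencyNotAvailable` is raised only when all three are unavailable. The sketch below compares that condition with a hypothetical `or`-style guard that would raise when any one is missing. It is an illustration of the boolean logic only and makes no claim about which behavior the change intends.

```python
from itertools import product


def guard_all_missing(torch_ok, accelerate_ok, backend_ok):
    """Condition as written in the hunks: true only when every check fails."""
    return not torch_ok and not accelerate_ok and not backend_ok


def guard_any_missing(torch_ok, accelerate_ok, backend_ok):
    """Hypothetical alternative: true when at least one check fails."""
    return not (torch_ok and accelerate_ok and backend_ok)


# Enumerate all availability combinations and keep those where the guards disagree.
differing = [
    combo
    for combo in product([True, False], repeat=3)
    if guard_all_missing(*combo) != guard_any_missing(*combo)
]
print(len(differing))
```

The two guards agree only when everything is available or nothing is; they disagree on the six mixed cases, for example torch installed but the quantization backend missing.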
@@ -112,7 +107,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -351,6 +345,7 @@ else:
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"CycleDiffusionPipeline",
@@ -409,7 +404,9 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LTXImageToVideoPipeline",
"LTXPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldIntrinsicsPipeline",
@@ -893,6 +890,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
CycleDiffusionPipeline,
@@ -951,7 +949,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LTXImageToVideoPipeline,
LTXPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
+1 -3
@@ -804,9 +804,7 @@ class SD3IPAdapterMixin:
}
self.register_modules(
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
image_encoder=SiglipVisionModel.from_pretrained(
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
).to(self.device),
+6 -2
@@ -423,8 +423,12 @@ def _load_lora_into_text_encoder(
# Unsafe code />
if prefix is not None and not state_dict:
logger.info(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
logger.warning(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
@@ -1348,3 +1348,56 @@ def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
converted_state_dict = {}
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
for i in range(num_blocks):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.lora_A.weight"
)
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.self_attn.{o}.lora_B.weight"
)
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
)
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
)
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
)
# FFN
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
f"blocks.{i}.{o}.lora_A.weight"
)
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
f"blocks.{i}.{o}.lora_B.weight"
)
if len(original_state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
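Note on `_convert_non_diffusers_wan_lora_to_diffusers` above: the renaming it performs can be illustrated per key. The following is a sketch only, not the function from the diff; `rename_wan_lora_key` and the lookup tables are hypothetical names, and the sketch works on key strings alone rather than on a real state dict.

```python
import re

# Lookup tables mirroring the name pairs zipped together in the hunk above.
MODULES = {"self_attn": "attn1", "cross_attn": "attn2"}
PROJECTIONS = {
    "q": "to_q",
    "k": "to_k",
    "v": "to_v",
    "o": "to_out.0",
    "k_img": "add_k_proj",
    "v_img": "add_v_proj",
}
FFN_LAYERS = {"0": "net.0.proj", "2": "net.2"}

KEY_PATTERN = re.compile(
    r"^diffusion_model\.blocks\.(\d+)\.(self_attn|cross_attn|ffn)\.(\w+)\.(lora_[AB]\.weight)$"
)


def rename_wan_lora_key(key):
    """Map one original-format key to its diffusers-style name (hypothetical helper)."""
    match = KEY_PATTERN.match(key)
    if match is None:
        raise ValueError(f"Unexpected key format: {key}")
    block, module, layer, suffix = match.groups()
    if module == "ffn":
        target = f"ffn.{FFN_LAYERS[layer]}"  # unknown layer names raise KeyError
    else:
        target = f"{MODULES[module]}.{PROJECTIONS[layer]}"
    return f"transformer.blocks.{block}.{target}.{suffix}"


print(rename_wan_lora_key("diffusion_model.blocks.0.self_attn.q.lora_A.weight"))
print(rename_wan_lora_key("diffusion_model.blocks.3.cross_attn.k_img.lora_B.weight"))
print(rename_wan_lora_key("diffusion_model.blocks.1.ffn.0.lora_A.weight"))
```

Unlike the function in the diff, this sketch does not check that every input key was consumed, and it would also accept `k_img`/`v_img` under `self_attn`, which the original only handles under `cross_attn`.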
+75 -25
@@ -42,6 +42,7 @@ from .lora_conversion_utils import (
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
@@ -451,7 +452,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -472,7 +477,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -891,7 +896,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
@@ -912,7 +921,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class SD3LoraLoaderMixin(LoraBaseMixin):
@@ -1290,7 +1299,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
@@ -1312,7 +1325,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class FluxLoraLoaderMixin(LoraBaseMixin):
@@ -1828,7 +1841,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
)
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
@@ -1849,7 +1866,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False):
@@ -2548,7 +2565,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
@@ -2566,7 +2587,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class Mochi1LoraLoaderMixin(LoraBaseMixin):
@@ -2852,7 +2873,11 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -2871,7 +2896,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3157,7 +3182,11 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3176,7 +3205,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class SanaLoraLoaderMixin(LoraBaseMixin):
@@ -3462,7 +3491,11 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3481,7 +3514,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@@ -3770,7 +3803,11 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -3789,7 +3826,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class Lumina2LoraLoaderMixin(LoraBaseMixin):
@@ -4079,7 +4116,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
@@ -4098,7 +4139,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class WanLoraLoaderMixin(LoraBaseMixin):
@@ -4111,7 +4152,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4198,6 +4238,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
user_agent=user_agent,
allow_pickle=allow_pickle,
)
if any(k.startswith("diffusion_model.") for k in state_dict):
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
@@ -4384,7 +4426,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4403,7 +4449,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class CogView4LoraLoaderMixin(LoraBaseMixin):
@@ -4689,7 +4735,11 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
@@ -4708,7 +4758,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
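Note on the `fuse_lora` / `unfuse_lora` hunks above: each mixin accepts `**kwargs`, and the hunks differ in whether those are passed on to `super()`. A minimal sketch of the difference, using hypothetical stand-in classes and a made-up `extra_option` argument rather than the real mixins:

```python
class Base:
    def fuse_lora(self, components, lora_scale=1.0, **kwargs):
        # Report everything the base method actually received.
        return {"components": components, "lora_scale": lora_scale, **kwargs}


class DropsKwargs(Base):
    def fuse_lora(self, components, lora_scale=1.0, **kwargs):
        # Extra keyword arguments are accepted but never reach the base class.
        return super().fuse_lora(components=components, lora_scale=lora_scale)


class ForwardsKwargs(Base):
    def fuse_lora(self, components, lora_scale=1.0, **kwargs):
        # Extra keyword arguments are passed through unchanged.
        return super().fuse_lora(components=components, lora_scale=lora_scale, **kwargs)


dropped = DropsKwargs().fuse_lora(["unet"], extra_option=True)
forwarded = ForwardsKwargs().fuse_lora(["unet"], extra_option=True)
print(dropped)
print(forwarded)
```

In the first subclass the caller's `extra_option` is silently ignored; in the second it arrives at the base implementation.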
+6 -2
@@ -354,8 +354,12 @@ class PeftAdapterMixin:
# Unsafe code />
if prefix is not None and not state_dict:
logger.info(
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {self.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
logger.warning(
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
def save_lora_adapter(
+13 -5
@@ -741,10 +741,14 @@ class Attention(nn.Module):
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
attention_mask = attention_mask.repeat_interleave(
head_size, dim=0, output_size=attention_mask.shape[0] * head_size
)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
attention_mask = attention_mask.repeat_interleave(
head_size, dim=1, output_size=attention_mask.shape[1] * head_size
)
return attention_mask
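Note on the `repeat_interleave` hunks in this and the following files: wherever `output_size` is passed, it is computed as `shape[dim] * repeats`. The sketch below is a pure-Python stand-in for a scalar `repeats` along one dimension (not the PyTorch kernel; `repeat_interleave_1d` is a hypothetical helper), showing why that product is the resulting length.

```python
def repeat_interleave_1d(values, repeats):
    # Each element is repeated `repeats` times in place, preserving order:
    # [a, b] with repeats=3 -> [a, a, a, b, b, b]
    return [value for value in values for _ in range(repeats)]


mask_rows = ["m0", "m1"]  # stands in for a batch of attention-mask rows
head_size = 3

expanded = repeat_interleave_1d(mask_rows, head_size)
expected_size = len(mask_rows) * head_size  # the value the diff passes as output_size

print(expanded)
print(len(expanded) == expected_size)
```

As I understand the PyTorch documentation, supplying `output_size` lets the library avoid computing the output length itself; the values produced are the same either way.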
@@ -2335,7 +2339,9 @@ class FluxAttnProcessor2_0:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -3704,8 +3710,10 @@ class StableAudioAttnProcessor2_0:
if kv_heads != attn.heads:
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
heads_per_kv_head = attn.heads // kv_heads
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
value = torch.repeat_interleave(
value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
)
if attn.norm_q is not None:
query = attn.norm_q(query)
@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
x = F.pixel_shuffle(x, self.factor)
if self.shortcut:
y = hidden_states.repeat_interleave(self.repeats, dim=1)
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
y = F.pixel_shuffle(y, self.factor)
hidden_states = x + y
else:
@@ -361,7 +361,9 @@ class Decoder(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.in_shortcut:
x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
x = hidden_states.repeat_interleave(
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
)
hidden_states = self.conv_in(hidden_states) + x
else:
hidden_states = self.conv_in(hidden_states)
@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
if self.down_sample:
identity = hidden_states[:, :, ::2]
elif self.up_sample:
identity = hidden_states.repeat_interleave(2, dim=2)
identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
else:
identity = hidden_states
@@ -426,7 +426,9 @@ class FourierFeatures(nn.Module):
w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1]
# Interleaved repeat of input channels to match w
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
h = inputs.repeat_interleave(
num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
) # [B, C * num_freqs, T, H, W]
# Scale channels by frequency.
h = w * h
@@ -687,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(sample_num_frames, dim=0)
emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
# 2. pre-process
batch_size, channels, num_frames, height, width = sample.shape
+8 -3
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
# 3. Concat
pos_embed_spatial = pos_embed_spatial[None, :, :]
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
pos_embed_spatial = pos_embed_spatial.repeat_interleave(
temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
) # [T, H*W, D // 4 * 3]
pos_embed_temporal = pos_embed_temporal[:, None, :]
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
@@ -1152,10 +1154,13 @@ def get_1d_rotary_pos_embed(
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
is_npu = freqs.device.type == "npu"
if is_npu:
freqs = freqs.float()
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio, allegro
@@ -227,13 +227,17 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
# Prepare text embeddings for spatial block
# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
)
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
# Prepare timesteps for spatial and temporal block
timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
timestep_spatial = timestep.repeat_interleave(
num_frame, dim=0, output_size=timestep.shape[0] * num_frame
).view(-1, timestep.shape[-1])
timestep_temp = timestep.repeat_interleave(
num_patches, dim=0, output_size=timestep.shape[0] * num_patches
).view(-1, timestep.shape[-1])
# Spatial and temporal transformer blocks
for i, (spatial_block, temp_block) in enumerate(
@@ -299,7 +303,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
).permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
embedded_timestep = embedded_timestep.repeat_interleave(
num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
).view(-1, embedded_timestep.shape[-1])
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
@@ -353,7 +353,11 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
attention_mask = attention_mask.repeat_interleave(
self.config.num_attention_heads,
dim=0,
output_size=attention_mask.shape[0] * self.config.num_attention_heads,
)
if self.norm_in is not None:
hidden_states = self.norm_in(hidden_states)
@@ -23,6 +23,7 @@ from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -126,7 +127,8 @@ class CogView4AttnProcessor:
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
@@ -156,6 +158,15 @@ class CogView4AttnProcessor:
)
# 4. Attention
if attention_mask is not None:
text_attention_mask = attention_mask.float().to(query.device)
actual_text_seq_length = text_attention_mask.size(1)
new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
new_attention_mask = new_attention_mask.unsqueeze(2)
attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -203,6 +214,8 @@ class CogView4TransformerBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# 1. Timestep conditioning
(
@@ -223,6 +236,8 @@ class CogView4TransformerBlock(nn.Module):
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -289,7 +304,7 @@ class CogView4RotaryPosEmbed(nn.Module):
return (freqs.cos(), freqs.sin())
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
patch_size (`int`, defaults to `2`):
@@ -386,6 +401,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
crop_coords: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
@@ -421,11 +438,11 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, image_rotary_emb
hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
)
# 4. Output norm & projection
@@ -441,6 +441,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
@@ -638,8 +638,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
)
# 2. pre-process
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+2 -2
@@ -592,7 +592,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
# 3. time + FPS embeddings.
emb = t_emb + fps_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
# 4. context embeddings.
# The context embeddings consist of both text embeddings from the input prompt
@@ -620,7 +620,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
image_emb = self.context_embedding(image_embeddings)
image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
context_emb = torch.cat([context_emb, image_emb], dim=1)
context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)
image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
image_latents.shape[0] * image_latents.shape[2],
@@ -2059,7 +2059,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
aug_emb = self.add_embedding(add_embeds)
emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
@@ -2068,7 +2068,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
image_embeds = [
image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
for image_embed in image_embeds
]
encoder_hidden_states = (encoder_hidden_states, image_embeds)
# 2. pre-process
@@ -431,9 +431,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
)
# 2. pre-process
sample = self.conv_in(sample)
+6 -6
@@ -154,7 +154,7 @@ else:
"CogVideoXFunControlPipeline",
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["controlnet"].extend(
[
@@ -265,8 +265,8 @@ else:
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
@@ -511,7 +511,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXVideoToVideoPipeline,
)
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4Pipeline
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .consisid import ConsisIDPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
@@ -619,8 +619,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaText2ImgPipeline
from .lumina2 import Lumina2Text2ImgPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
+6 -5
@@ -22,7 +22,7 @@ from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4Pipeline
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
@@ -69,8 +69,8 @@ from .kandinsky2_2 import (
)
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .lumina import LuminaText2ImgPipeline
from .lumina2 import Lumina2Text2ImgPipeline
from .lumina import LuminaPipeline
from .lumina2 import Lumina2Pipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
@@ -141,10 +141,11 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux", FluxPipeline),
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("lumina2", Lumina2Text2ImgPipeline),
("lumina", LuminaPipeline),
("lumina2", Lumina2Pipeline),
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("cogview4-control", CogView4ControlPipeline),
]
)
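Reviewer note: `AUTO_TEXT2IMAGE_PIPELINES_MAPPING` is an `OrderedDict` from short model keys to pipeline classes, and this hunk registers the new `"cogview4-control"` key. The sketch below shows that kind of key-to-class table and a reverse lookup over it. The classes are empty stand-ins, and `TOY_TEXT2IMAGE_MAPPING` and `key_for` are hypothetical names; the library's real resolution logic is not shown in this diff.

```python
from collections import OrderedDict


class CogView4Pipeline:
    """Empty stand-in for the real pipeline class."""


class CogView4ControlPipeline:
    """Empty stand-in for the control pipeline added in this diff."""


TOY_TEXT2IMAGE_MAPPING = OrderedDict(
    [
        ("cogview4", CogView4Pipeline),
        ("cogview4-control", CogView4ControlPipeline),
    ]
)


def key_for(pipeline_cls):
    """Return the mapping key registered for `pipeline_cls`, or None if absent."""
    for key, cls in TOY_TEXT2IMAGE_MAPPING.items():
        if cls is pipeline_cls:
            return key
    return None


print(key_for(CogView4ControlPipeline))
```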
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogview4"] = ["CogView4Pipeline"]
_import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -31,6 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_cogview4 import CogView4Pipeline
from .pipeline_cogview4_control import CogView4ControlPipeline
else:
import sys
@@ -389,14 +389,18 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -533,6 +537,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# Default call parameters
@@ -610,6 +615,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -661,6 +667,8 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
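Reviewer note: the hunks above add a `current_timestep` property that is set to the active timestep on each loop iteration and reset to `None` before and after the loop. The toy loop below sketches that life cycle together with the existing `interrupt` check. `ToyPipeline` and `stop_after_two` are hypothetical names, and no denoising happens.

```python
from typing import Callable, List, Optional


class ToyPipeline:
    """Hypothetical stand-in that mirrors only the bookkeeping, not the model calls."""

    def __init__(self) -> None:
        self._current_timestep: Optional[int] = None
        self._interrupt = False

    @property
    def current_timestep(self) -> Optional[int]:
        return self._current_timestep

    @property
    def interrupt(self) -> bool:
        return self._interrupt

    def __call__(
        self, timesteps: List[int], on_step_end: Callable[["ToyPipeline", int, int], None]
    ) -> List[int]:
        self._current_timestep = None
        self._interrupt = False
        processed = []
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue  # remaining steps are skipped, as in the pipeline loop
            self._current_timestep = t
            processed.append(t)
            on_step_end(self, i, t)
        self._current_timestep = None  # cleared once the loop is done
        return processed


def stop_after_two(pipe: ToyPipeline, step: int, timestep: int) -> None:
    # A callback can read the live timestep and request an early stop.
    assert pipe.current_timestep == timestep
    if step == 1:
        pipe._interrupt = True


pipe = ToyPipeline()
processed = pipe([999, 666, 333, 1], stop_after_two)
print(processed, pipe.current_timestep)
```

Resetting the attribute to `None` outside the loop lets callers tell "no generation in progress" apart from a real timestep.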
@@ -0,0 +1,727 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import AutoTokenizer, GlmModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKL, CogView4Transformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView4PipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CogView4ControlPipeline
>>> from diffusers.utils import load_image
>>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
>>> control_image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
... )
>>> prompt = "A bird in space"
>>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0]
>>> image.save("cogview4-control.png")
```
"""
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
base_shift: float = 0.25,
max_shift: float = 0.75,
) -> float:
m = (image_seq_len / base_seq_len) ** 0.5
mu = m * max_shift + base_shift
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogView4ControlPipeline(DiffusionPipeline):
r"""
Pipeline for control-image-conditioned text-to-image generation using CogView4.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`GlmModel`]):
Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
tokenizer (`PreTrainedTokenizer`):
Tokenizer of class
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
transformer ([`CogView4Transformer2DModel`]):
A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: GlmModel,
vae: AutoencoderKL,
transformer: CogView4Transformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds
def _get_glm_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 1024,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer(
prompt,
padding="longest", # not use max length
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
current_length = text_input_ids.shape[1]
pad_length = (16 - (current_length % 16)) % 16
if pad_length > 0:
pad_ids = torch.full(
(text_input_ids.shape[0], pad_length),
fill_value=self.tokenizer.pad_token_id,
dtype=text_input_ids.dtype,
device=text_input_ids.device,
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = self.text_encoder(
text_input_ids.to(self.text_encoder.device), output_hidden_states=True
).hidden_states[-2]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 1024,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
max_sequence_length (`int`, defaults to `1024`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
seq_len = prompt_embeds.size(1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
seq_len = negative_prompt_embeds.size(1)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 1024,
) -> Union[CogView4PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
not greater than `1`).
control_image (`PipelineImageInput`, *optional*):
The control image used to condition the generation. It is preprocessed to `height` x `width`, encoded
with the VAE, and concatenated with the latents along the channel dimension.
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. If not provided, it is set to 1024.
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. If not provided it is set to 1024.
num_inference_steps (`int`, *optional*, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.cogview4.pipeline_output.CogView4PipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `1024`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
Examples:
Returns:
[`~pipelines.cogview4.pipeline_output.CogView4PipelineOutput`] or `tuple`:
[`~pipelines.cogview4.pipeline_output.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = (height, width)
# Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
# Prepare latents
latent_channels = self.transformer.config.in_channels // 2
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
)
height, width = control_image.shape[-2:]
vae_shift_factor = 0
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
torch.float32,
device,
generator,
latents,
)
# Prepare additional timestep conditions
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
# Prepare timesteps
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
)
timesteps = (
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
if timesteps is None
else np.array(timesteps)
)
timesteps = timesteps.astype(np.int64).astype(np.float32)
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
)
self._num_timesteps = len(timesteps)
# Denoising loop
transformer_dtype = self.transformer.dtype
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return CogView4PipelineOutput(images=image)
+2 -2
@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"]
_import_structure["pipeline_lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -32,7 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_lumina import LuminaText2ImgPipeline
from .pipeline_lumina import LuminaPipeline, LuminaText2ImgPipeline
else:
import sys
@@ -30,6 +30,7 @@ from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
@@ -60,11 +61,9 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import LuminaText2ImgPipeline
>>> from diffusers import LuminaPipeline
>>> pipe = LuminaText2ImgPipeline.from_pretrained(
... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe = LuminaPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16)
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
@@ -134,7 +133,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class LuminaText2ImgPipeline(DiffusionPipeline):
class LuminaPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Lumina-T2I.
@@ -932,3 +931,23 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
return (image,)
return ImagePipelineOutput(images=image)
class LuminaText2ImgPipeline(LuminaPipeline):
def __init__(
self,
transformer: LuminaNextDiT2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: GemmaPreTrainedModel,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
):
deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead."
deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message)
super().__init__(
transformer=transformer,
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
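The rename above keeps the old class name importable as a thin subclass that warns on construction. A minimal, library-free sketch of that pattern follows; `NewPipeline` and `OldPipeline` are hypothetical stand-ins, and `warnings.warn` is used here in place of the `deprecate` helper from `diffusers.utils`, whose exact behaviour is not reproduced.

```python
import warnings


class NewPipeline:
    """Hypothetical stand-in for the renamed class."""

    def __init__(self, transformer, scheduler):
        self.transformer = transformer
        self.scheduler = scheduler


class OldPipeline(NewPipeline):
    """Deprecated alias: same behaviour, but warns when instantiated."""

    def __init__(self, transformer, scheduler):
        # Assumption: a FutureWarning stands in for the library's own deprecation helper.
        warnings.warn(
            "`OldPipeline` has been renamed to `NewPipeline`; please use `NewPipeline` instead.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(transformer=transformer, scheduler=scheduler)
```

Because the alias subclasses the new class, `isinstance` checks against the new name keep passing for code that still constructs the old one.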
+2 -2
@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_lumina2"] = ["Lumina2Text2ImgPipeline"]
_import_structure["pipeline_lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -32,7 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_lumina2 import Lumina2Text2ImgPipeline
from .pipeline_lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
else:
import sys
@@ -25,6 +25,7 @@ from ...models import AutoencoderKL
from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -47,9 +48,9 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import Lumina2Text2ImgPipeline
>>> from diffusers import Lumina2Pipeline
>>> pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
>>> pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
@@ -133,7 +134,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Lumina-T2I.
@@ -767,3 +768,23 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
return (image,)
return ImagePipelineOutput(images=image)
class Lumina2Text2ImgPipeline(Lumina2Pipeline):
def __init__(
self,
transformer: Lumina2Transformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: Gemma2PreTrainedModel,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
):
deprecation_message = "`Lumina2Text2ImgPipeline` has been renamed to `Lumina2Pipeline` and will be removed in a future version. Please use `Lumina2Pipeline` instead."
deprecate("diffusers.pipelines.lumina2.pipeline_lumina2.Lumina2Text2ImgPipeline", "0.34", deprecation_message)
super().__init__(
transformer=transformer,
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
+5 -3
@@ -1610,7 +1610,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
expected_modules.add(name)
optional_parameters.remove(name)
return expected_modules, optional_parameters
return sorted(expected_modules), sorted(optional_parameters)
@classmethod
def _get_signature_types(cls):
@@ -1652,10 +1652,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
if set(components.keys()) != expected_modules:
actual = sorted(set(components.keys()))
expected = sorted(expected_modules)
if actual != expected:
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components.keys()} are defined."
f" {expected} to be defined, but {actual} are defined."
)
return components
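The change above compares sorted lists instead of sets, so the error message lists names in a stable order. A small sketch of the same idea, assuming a hypothetical standalone `check_components` function rather than the pipeline property:

```python
def check_components(components: dict, expected_modules: set) -> dict:
    """Validate component names and report mismatches in a deterministic order."""
    # Sorting both sides makes the comparison and the message independent of
    # set/dict iteration order.
    actual = sorted(components.keys())
    expected = sorted(expected_modules)
    if actual != expected:
        raise ValueError(f"Expected {expected} to be defined, but {actual} are defined.")
    return components
```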
+26 -10
@@ -109,14 +109,30 @@ def prompt_clean(text):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
encoder_output: torch.Tensor,
latents_mean: torch.Tensor,
latents_std: torch.Tensor,
generator: Optional[torch.Generator] = None,
sample_mode: str = "sample",
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
encoder_output.latent_dist.logvar = torch.clamp(
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
)
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
encoder_output.latent_dist.logvar = torch.clamp(
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
)
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
return (encoder_output.latents - latents_mean) * latents_std
else:
raise AttributeError("Could not access latents of provided encoder_output")
@@ -385,13 +401,6 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
)
video_condition = video_condition.to(device=device, dtype=dtype)
if isinstance(generator, list):
latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
latents = latent_condition = torch.cat(latent_condition)
else:
latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -401,7 +410,14 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents.device, latents.dtype
)
latent_condition = (latent_condition - latents_mean) * latents_std
if isinstance(generator, list):
latent_condition = [
retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
]
latent_condition = torch.cat(latent_condition)
else:
latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, num_frames))] = 0
@@ -23,7 +23,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from packaging import version
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
from ...utils import (
get_module_from_name,
is_torch_available,
is_torch_version,
is_torchao_available,
is_torchao_version,
logging,
)
from ..base import DiffusersQuantizer
@@ -62,6 +69,43 @@ if is_torchao_available():
from torchao.quantization import quantize_
def _update_torch_safe_globals():
safe_globals = [
(torch.uint1, "torch.uint1"),
(torch.uint2, "torch.uint2"),
(torch.uint3, "torch.uint3"),
(torch.uint4, "torch.uint4"),
(torch.uint5, "torch.uint5"),
(torch.uint6, "torch.uint6"),
(torch.uint7, "torch.uint7"),
]
try:
from torchao.dtypes import NF4Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
)
logger.debug(e)
finally:
torch.serialization.add_safe_globals(safe_globals=safe_globals)
if (
is_torch_available()
and is_torch_version(">=", "2.6.0")
and is_torchao_available()
and is_torchao_version(">=", "0.7.0")
):
_update_torch_safe_globals()
logger = logging.get_logger(__name__)
+1
@@ -94,6 +94,7 @@ from .import_utils import (
is_torch_xla_available,
is_torch_xla_version,
is_torchao_available,
is_torchao_version,
is_torchsde_available,
is_torchvision_available,
is_transformers_available,
+11
@@ -56,3 +56,14 @@ USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
if USE_PEFT_BACKEND and _CHECK_PEFT:
dep_version_check("peft")
DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
@@ -362,6 +362,21 @@ class CogView3PlusPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class CogView4ControlPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CogView4Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1232,6 +1247,21 @@ class LTXPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Lumina2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Lumina2Text2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1247,6 +1277,21 @@ class Lumina2Text2ImgPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LuminaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LuminaText2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+76 -234
@@ -25,7 +25,6 @@ from types import ModuleType
from typing import Any, Union
from huggingface_hub.utils import is_jinja_available # noqa: F401
from packaging import version
from packaging.version import Version, parse
from . import logging
@@ -52,36 +51,30 @@ DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
_torch_version = "N/A"
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
def _is_package_available(pkg_name: str):
pkg_exists = importlib.util.find_spec(pkg_name) is not None
pkg_version = "N/A"
if pkg_exists:
try:
_torch_version = importlib_metadata.version("torch")
logger.info(f"PyTorch version {_torch_version} available.")
except importlib_metadata.PackageNotFoundError:
_torch_available = False
pkg_version = importlib_metadata.version(pkg_name)
logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
except (ImportError, importlib_metadata.PackageNotFoundError):
pkg_exists = False
return pkg_exists, pkg_version
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available, _torch_version = _is_package_available("torch")
else:
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
try:
_torch_xla_version = importlib_metadata.version("torch_xla")
logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
except ImportError:
_torch_xla_available = False
# check whether torch_npu is available
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
if _torch_npu_available:
try:
_torch_npu_version = importlib_metadata.version("torch_npu")
logger.info(f"torch_npu version {_torch_npu_version} available.")
except ImportError:
_torch_npu_available = False
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
@@ -97,47 +90,12 @@ else:
_flax_available = False
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
_safetensors_available = importlib.util.find_spec("safetensors") is not None
if _safetensors_available:
try:
_safetensors_version = importlib_metadata.version("safetensors")
logger.info(f"Safetensors version {_safetensors_version} available.")
except importlib_metadata.PackageNotFoundError:
_safetensors_available = False
_safetensors_available, _safetensors_version = _is_package_available("safetensors")
else:
logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
_safetensors_available = False
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
_transformers_version = importlib_metadata.version("transformers")
logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
_transformers_available = False
_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None
try:
_hf_hub_version = importlib_metadata.version("huggingface_hub")
logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}")
except importlib_metadata.PackageNotFoundError:
_hf_hub_available = False
_inflect_available = importlib.util.find_spec("inflect") is not None
try:
_inflect_version = importlib_metadata.version("inflect")
logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
_inflect_available = False
_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
_unidecode_version = importlib_metadata.version("unidecode")
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
except importlib_metadata.PackageNotFoundError:
_unidecode_available = False
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
@@ -186,85 +144,6 @@ try:
except importlib_metadata.PackageNotFoundError:
_opencv_available = False
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
_scipy_version = importlib_metadata.version("scipy")
logger.debug(f"Successfully imported scipy version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
_scipy_available = False
_librosa_available = importlib.util.find_spec("librosa") is not None
try:
_librosa_version = importlib_metadata.version("librosa")
logger.debug(f"Successfully imported librosa version {_librosa_version}")
except importlib_metadata.PackageNotFoundError:
_librosa_available = False
_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
_accelerate_version = importlib_metadata.version("accelerate")
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
except importlib_metadata.PackageNotFoundError:
_accelerate_available = False
_xformers_available = importlib.util.find_spec("xformers") is not None
try:
_xformers_version = importlib_metadata.version("xformers")
if _torch_available:
_torch_version = importlib_metadata.version("torch")
if version.Version(_torch_version) < version.Version("1.12"):
raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
_xformers_available = False
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
try:
_k_diffusion_version = importlib_metadata.version("k_diffusion")
logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
except importlib_metadata.PackageNotFoundError:
_k_diffusion_available = False
_note_seq_available = importlib.util.find_spec("note_seq") is not None
try:
_note_seq_version = importlib_metadata.version("note_seq")
logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
except importlib_metadata.PackageNotFoundError:
_note_seq_available = False
_wandb_available = importlib.util.find_spec("wandb") is not None
try:
_wandb_version = importlib_metadata.version("wandb")
logger.debug(f"Successfully imported wandb version {_wandb_version }")
except importlib_metadata.PackageNotFoundError:
_wandb_available = False
_tensorboard_available = importlib.util.find_spec("tensorboard")
try:
_tensorboard_version = importlib_metadata.version("tensorboard")
logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
_tensorboard_available = False
_compel_available = importlib.util.find_spec("compel")
try:
_compel_version = importlib_metadata.version("compel")
logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
_compel_available = False
_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
_ftfy_version = importlib_metadata.version("ftfy")
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
_ftfy_available = False
_bs4_available = importlib.util.find_spec("bs4") is not None
try:
# importlib metadata under different name
@@ -273,13 +152,6 @@ try:
except importlib_metadata.PackageNotFoundError:
_bs4_available = False
_torchsde_available = importlib.util.find_spec("torchsde") is not None
try:
_torchsde_version = importlib_metadata.version("torchsde")
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
except importlib_metadata.PackageNotFoundError:
_torchsde_available = False
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
_invisible_watermark_version = importlib_metadata.version("invisible-watermark")
@@ -287,91 +159,42 @@ try:
except importlib_metadata.PackageNotFoundError:
_invisible_watermark_available = False
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
_note_seq_available, _note_seq_version = _is_package_available("note_seq")
_wandb_available, _wandb_version = _is_package_available("wandb")
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
_compel_available, _compel_version = _is_package_available("compel")
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
_peft_available, _peft_version = _is_package_available("peft")
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib")
_timm_available, _timm_version = _is_package_available("timm")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_imageio_available, _imageio_version = _is_package_available("imageio")
_ftfy_available, _ftfy_version = _is_package_available("ftfy")
_scipy_available, _scipy_version = _is_package_available("scipy")
_librosa_available, _librosa_version = _is_package_available("librosa")
_accelerate_available, _accelerate_version = _is_package_available("accelerate")
_xformers_available, _xformers_version = _is_package_available("xformers")
_gguf_available, _gguf_version = _is_package_available("gguf")
_torchao_available, _torchao_version = _is_package_available("torchao")
_peft_available = importlib.util.find_spec("peft") is not None
try:
_peft_version = importlib_metadata.version("peft")
logger.debug(f"Successfully imported peft version {_peft_version}")
except importlib_metadata.PackageNotFoundError:
_peft_available = False
_torchvision_available = importlib.util.find_spec("torchvision") is not None
try:
_torchvision_version = importlib_metadata.version("torchvision")
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
except importlib_metadata.PackageNotFoundError:
_torchvision_available = False
_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None
try:
_sentencepiece_version = importlib_metadata.version("sentencepiece")
logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}")
except importlib_metadata.PackageNotFoundError:
_sentencepiece_available = False
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
try:
_matplotlib_version = importlib_metadata.version("matplotlib")
logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
except importlib_metadata.PackageNotFoundError:
_matplotlib_available = False
_timm_available = importlib.util.find_spec("timm") is not None
if _timm_available:
try:
_timm_version = importlib_metadata.version("timm")
logger.info(f"Timm version {_timm_version} available.")
except importlib_metadata.PackageNotFoundError:
_timm_available = False
def is_timm_available():
return _timm_available
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
try:
_bitsandbytes_version = importlib_metadata.version("bitsandbytes")
logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
except importlib_metadata.PackageNotFoundError:
_bitsandbytes_available = False
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
_imageio_available = importlib.util.find_spec("imageio") is not None
if _imageio_available:
try:
_imageio_version = importlib_metadata.version("imageio")
logger.debug(f"Successfully imported imageio version {_imageio_version}")
except importlib_metadata.PackageNotFoundError:
_imageio_available = False
_is_gguf_available = importlib.util.find_spec("gguf") is not None
if _is_gguf_available:
try:
_gguf_version = importlib_metadata.version("gguf")
logger.debug(f"Successfully import gguf version {_gguf_version}")
except importlib_metadata.PackageNotFoundError:
_is_gguf_available = False
_is_torchao_available = importlib.util.find_spec("torchao") is not None
if _is_torchao_available:
try:
_torchao_version = importlib_metadata.version("torchao")
logger.debug(f"Successfully import torchao version {_torchao_version}")
except importlib_metadata.PackageNotFoundError:
_is_torchao_available = False
_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
if _is_optimum_quanto_available:
_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
if _optimum_quanto_available:
try:
_optimum_quanto_version = importlib_metadata.version("optimum_quanto")
logger.debug(f"Successfully imported optimum-quanto version {_optimum_quanto_version}")
except importlib_metadata.PackageNotFoundError:
_is_optimum_quanto_available = False
_optimum_quanto_available = False
def is_torch_available():
@@ -495,15 +318,19 @@ def is_imageio_available():
def is_gguf_available():
return _is_gguf_available
return _gguf_available
def is_torchao_available():
return _is_torchao_available
return _torchao_available
def is_optimum_quanto_available():
return _is_optimum_quanto_available
return _optimum_quanto_available
def is_timm_available():
return _timm_available
# docstyle-ignore
@@ -863,11 +690,26 @@ def is_gguf_version(operation: str, version: str):
version (`str`):
A version string
"""
if not _is_gguf_available:
if not _gguf_available:
return False
return compare_versions(parse(_gguf_version), operation, version)
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _torchao_available:
return False
return compare_versions(parse(_torchao_version), operation, version)
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
@@ -893,7 +735,7 @@ def is_optimum_quanto_version(operation: str, version: str):
version (`str`):
A version string
"""
if not _is_optimum_quanto_available:
if not _optimum_quanto_available:
return False
return compare_versions(parse(_optimum_quanto_version), operation, version)
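The `_is_package_available` helper introduced in this file replaces many near-identical try/except blocks with a single function returning an `(available, version)` pair. The sketch below shows the same two-step check using only the standard library; the public name `is_package_available` and the omission of logging are simplifications for illustration.

```python
import importlib.metadata
import importlib.util


def is_package_available(pkg_name: str):
    """Return (available, version) for an importable, installed distribution."""
    # Step 1: can the import system locate a module with this name?
    pkg_exists = importlib.util.find_spec(pkg_name) is not None
    pkg_version = "N/A"
    if pkg_exists:
        try:
            # Step 2: is there installed distribution metadata carrying a version?
            pkg_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # Importable but no metadata under this name: treat as unavailable.
            pkg_exists = False
    return pkg_exists, pkg_version
```

One consequence of the design is that a module which is importable but has no distribution metadata under the same name is reported as unavailable, which is why packages whose import name differs from their distribution name still need bespoke handling.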
+97 -6
@@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str:
return "unknown"
def check_inputs(
def check_inputs_decode(
endpoint: str,
tensor: "torch.Tensor",
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
@@ -89,7 +89,7 @@ def check_inputs(
)
def postprocess(
def postprocess_decode(
response: requests.Response,
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
output_type: Literal["mp4", "pil", "pt"] = "pil",
@@ -142,7 +142,7 @@ def postprocess(
return output
def prepare(
def prepare_decode(
tensor: "torch.Tensor",
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
do_scaling: bool = True,
@@ -293,7 +293,7 @@ def remote_decode(
standard_warn=False,
)
output_tensor_type = "binary"
check_inputs(
check_inputs_decode(
endpoint,
tensor,
processor,
@@ -309,7 +309,7 @@ def remote_decode(
height,
width,
)
kwargs = prepare(
kwargs = prepare_decode(
tensor=tensor,
processor=processor,
do_scaling=do_scaling,
@@ -324,7 +324,7 @@ def remote_decode(
response = requests.post(endpoint, **kwargs)
if not response.ok:
raise RuntimeError(response.json())
output = postprocess(
output = postprocess_decode(
response=response,
processor=processor,
output_type=output_type,
@@ -332,3 +332,94 @@ def remote_decode(
partial_postprocess=partial_postprocess,
)
return output
def check_inputs_encode(
endpoint: str,
image: Union["torch.Tensor", Image.Image],
scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
):
pass
def postprocess_encode(
response: requests.Response,
):
output_tensor = response.content
parameters = response.headers
shape = json.loads(parameters["shape"])
dtype = parameters["dtype"]
torch_dtype = DTYPE_MAP[dtype]
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
return output_tensor
def prepare_encode(
image: Union["torch.Tensor", Image.Image],
scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
):
headers = {}
parameters = {}
if scaling_factor is not None:
parameters["scaling_factor"] = scaling_factor
if shift_factor is not None:
parameters["shift_factor"] = shift_factor
if isinstance(image, torch.Tensor):
data = safetensors.torch._tobytes(image, "tensor")
parameters["shape"] = list(image.shape)
parameters["dtype"] = str(image.dtype).split(".")[-1]
else:
buffer = io.BytesIO()
image.save(buffer, format="PNG")
data = buffer.getvalue()
return {"data": data, "params": parameters, "headers": headers}
def remote_encode(
endpoint: str,
image: Union["torch.Tensor", Image.Image],
scaling_factor: Optional[float] = None,
shift_factor: Optional[float] = None,
) -> "torch.Tensor":
"""
Hugging Face Hybrid Inference that allows running VAE encode remotely.
Args:
endpoint (`str`):
Endpoint for Remote Encode.
image (`torch.Tensor` or `PIL.Image.Image`):
Image to be encoded.
scaling_factor (`float`, *optional*):
Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`].
- SD v1: 0.18215
- SD XL: 0.13025
- Flux: 0.3611
If `None`, input must be passed with scaling applied.
shift_factor (`float`, *optional*):
Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`.
- Flux: 0.1159
If `None`, input must be passed with scaling applied.
Returns:
output (`torch.Tensor`).
"""
check_inputs_encode(
endpoint,
image,
scaling_factor,
shift_factor,
)
kwargs = prepare_encode(
image=image,
scaling_factor=scaling_factor,
shift_factor=shift_factor,
)
response = requests.post(endpoint, **kwargs)
if not response.ok:
raise RuntimeError(response.json())
output = postprocess_encode(
response=response,
)
return output
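`postprocess_encode` above rebuilds a tensor from raw response bytes plus `shape` and `dtype` response headers. The sketch below illustrates that header-driven decoding with the standard library only, so it runs without `torch`. The function name `decode_tensor_payload`, the little-endian byte order, and the two-entry dtype table are assumptions for illustration, not details of the endpoint.

```python
import json
import struct

# Assumption: only these two dtypes are handled in this sketch.
_STRUCT_CODES = {"float32": "f", "float16": "e"}


def decode_tensor_payload(content: bytes, headers: dict):
    """Return (shape, flat_values) decoded from raw bytes and header metadata."""
    shape = json.loads(headers["shape"])  # e.g. "[1, 4, 64, 64]"
    code = _STRUCT_CODES[headers["dtype"]]
    count = 1
    for dim in shape:
        count *= dim
    # Assumption: little-endian payload; struct.unpack raises struct.error
    # if the byte length does not match the declared shape.
    values = struct.unpack(f"<{count}{code}", content)
    return shape, list(values)
```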
+16
@@ -101,6 +101,8 @@ if is_torch_available():
mps_backend_registered = hasattr(torch.backends, "mps")
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
from .torch_utils import get_torch_cuda_device_capability
def torch_all_close(a, b, *args, **kwargs):
if not is_torch_available():
@@ -282,6 +284,20 @@ def require_torch_gpu(test_case):
)
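def require_torch_cuda_compatibility(expected_compute_capability):
    def decorator(test_case):
        if not torch.cuda.is_available():
            # `unittest.skip` takes a reason and returns a decorator, so apply it to the test.
            return unittest.skip("Test requires CUDA.")(test_case)
        current_compute_capability = get_torch_cuda_device_capability()
        # `unittest.skipUnless` also returns a decorator; apply it so the test case itself is returned.
        return unittest.skipUnless(
            float(current_compute_capability) == float(expected_compute_capability),
            "Test not supported for this compute capability.",
        )(test_case)
    return decorator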
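`unittest.skip` and `unittest.skipUnless` take a reason (and a condition) and return a decorator, so a decorator factory has to apply that result to the test case before returning it. The sketch below shows the pattern with the standard library only. `require_capability` and its `get_current` callback are hypothetical stand-ins so that the example runs without CUDA; they are not part of `testing_utils`.

```python
import unittest


def require_capability(expected, get_current):
    """Hypothetical capability-gated skip decorator (stdlib only)."""

    def decorator(test_case):
        current = get_current()
        if current is None:
            # No device reported: skip, applying the skip decorator explicitly.
            return unittest.skip("No device available.")(test_case)
        return unittest.skipUnless(
            float(current) == float(expected),
            "Test not supported for this compute capability.",
        )(test_case)

    return decorator


@require_capability(8.0, lambda: 7.5)
class SkippedOnMismatch(unittest.TestCase):
    def test_runs(self):
        self.assertTrue(True)


@require_capability(8.0, lambda: "8.0")
class RunsOnMatch(unittest.TestCase):
    def test_runs(self):
        self.assertTrue(True)


def run_case(case):
    """Run one TestCase class and return its TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(case)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Both decorated names remain `unittest.TestCase` subclasses, so test discovery still finds them and reports the mismatch case as skipped.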
# These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
+2 -2
@@ -1961,7 +1961,7 @@ class PeftLoraLoaderMixinTests:
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
logger = logging.get_logger("diffusers.loaders.peft")
logger.setLevel(logging.INFO)
logger.setLevel(logging.WARNING)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(no_op_state_dict)
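The hunk above lists both `logging.INFO` and `logging.WARNING` for the same `setLevel` call; assuming the usual diff order (removed line first), the captured level moves from INFO to WARNING. A minimal sketch of what that changes, using the standard `logging` module rather than the diffusers logging wrapper or `CaptureLogger`. The logger name and messages are made up for illustration.

```python
import io
import logging

# Hypothetical logger name; propagation is disabled to keep the output local.
logger = logging.getLogger("example.loaders.peft")
logger.propagate = False


def capture(level, emit):
    """Collect everything `emit` logs at or above `level` into a string."""
    stream = io.StringIO()
    handler = logging.StreamHandler(stream)
    logger.addHandler(handler)
    logger.setLevel(level)
    try:
        emit(logger)
    finally:
        logger.removeHandler(handler)
    return stream.getvalue()


def emit_messages(log):
    log.info("loaded adapter")
    log.warning("No LoRA keys found")


at_info = capture(logging.INFO, emit_messages)
at_warning = capture(logging.WARNING, emit_messages)
```

At WARNING the informational record is filtered out, so an assertion on the captured text only sees the warning the test cares about.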
@@ -1981,7 +1981,7 @@ class PeftLoraLoaderMixinTests:
prefix = "text_encoder_2"
logger = logging.get_logger("diffusers.loaders.lora_base")
logger.setLevel(logging.INFO)
logger.setLevel(logging.WARNING)
with CaptureLogger(logger) as cap_logger:
self.pipeline_class.load_lora_into_text_encoder(
+17 -5
@@ -5,7 +5,13 @@ import numpy as np
import torch
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
LuminaNextDiT2DModel,
LuminaPipeline,
LuminaText2ImgPipeline,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
@@ -17,8 +23,8 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import PipelineTesterMixin
class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = LuminaText2ImgPipeline
class LuminaPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = LuminaPipeline
params = frozenset(
[
"prompt",
@@ -99,11 +105,17 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
def test_xformers_attention_forwardGenerator_pass(self):
pass
def test_deprecation_raises_warning(self):
with self.assertWarns(FutureWarning) as warning:
_ = LuminaText2ImgPipeline(**self.get_dummy_components()).to(torch_device)
warning_message = str(warning.warnings[0].message)
assert "renamed to `LuminaPipeline`" in warning_message
@slow
@require_torch_accelerator
class LuminaText2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = LuminaText2ImgPipeline
class LuminaPipelineSlowTests(unittest.TestCase):
pipeline_class = LuminaPipeline
repo_id = "Alpha-VLLM/Lumina-Next-SFT-diffusers"
def setUp(self):
@@ -6,15 +6,17 @@ from transformers import AutoTokenizer, Gemma2Config, Gemma2Model
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
Lumina2Transformer2DModel,
)
from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin
class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = Lumina2Text2ImgPipeline
class Lumina2PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = Lumina2Pipeline
params = frozenset(
[
"prompt",
@@ -115,3 +117,9 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester
"output_type": "np",
}
return inputs
def test_deprecation_raises_warning(self):
with self.assertWarns(FutureWarning) as warning:
_ = Lumina2Text2ImgPipeline(**self.get_dummy_components()).to(torch_device)
warning_message = str(warning.warnings[0].message)
assert "renamed to `Lumina2Pipeline`" in warning_message
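The two `test_deprecation_raises_warning` tests check that constructing the old class name emits a `FutureWarning` mentioning the new name. A minimal sketch of that rename pattern follows, with hypothetical classes `NewPipeline` and `OldPipeline`. It uses `warnings.warn` directly, which is an assumption; the library may route this through its own deprecation helper.

```python
import unittest
import warnings


class NewPipeline:
    def __init__(self, scheduler=None):
        self.scheduler = scheduler


class OldPipeline(NewPipeline):
    """Deprecated alias kept so that existing imports continue to work."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "`OldPipeline` has been renamed to `NewPipeline` and will be removed in a future version.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


class DeprecationTest(unittest.TestCase):
    def test_deprecation_raises_warning(self):
        with self.assertWarns(FutureWarning) as warning:
            _ = OldPipeline()
        self.assertIn("renamed to `NewPipeline`", str(warning.warnings[0].message))


def run_deprecation_test():
    """Run DeprecationTest and return its TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(DeprecationTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Subclassing the new class keeps `isinstance` checks against `NewPipeline` working for code that still constructs the old name.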
+102 -1
@@ -19,7 +19,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
from diffusers.utils.testing_utils import torch_device
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -826,3 +826,104 @@ class ProgressBarTests(unittest.TestCase):
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
@require_torch_gpu
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
expected_pipe_device = torch.device("cuda:0")
expected_pipe_dtype = torch.float64
def get_dummy_components_image_generation(self):
cross_attention_dim = 8
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=1,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=16,
layer_norm_eps=1e-05,
num_attention_heads=2,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
def test_deterministic_device(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionPipeline(**components)
pipe.to(device=torch_device, dtype=torch.float32)
pipe.unet.to(device="cpu")
pipe.vae.to(device="cuda")
pipe.text_encoder.to(device="cuda:0")
pipe_device = pipe.device
self.assertEqual(
self.expected_pipe_device,
pipe_device,
f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
)
def test_deterministic_dtype(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionPipeline(**components)
pipe.to(device=torch_device, dtype=torch.float32)
pipe.unet.to(dtype=torch.float16)
pipe.vae.to(dtype=torch.float32)
pipe.text_encoder.to(dtype=torch.float64)
pipe_dtype = pipe.dtype
self.assertEqual(
self.expected_pipe_dtype,
pipe_dtype,
f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
)
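The two tests above expect `pipe.device` and `pipe.dtype` to return the same value on every run even when components disagree. The sketch below illustrates one way such a lookup could be made independent of iteration order: visit component names in sorted order and take the first module found. It is not taken from the PR; `FakeModule` and `first_module_attr` are hypothetical, and devices and dtypes are plain strings so that the example runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class FakeModule:
    """Stand-in for a torch module; only carries a device and a dtype label."""

    device: str
    dtype: str


def first_module_attr(components: Dict[str, Any], attr: str) -> Optional[str]:
    """Return `attr` of the first module, visiting names in sorted order."""
    for name in sorted(components):
        module = components[name]
        if isinstance(module, FakeModule):
            return getattr(module, attr)
    return None


components = {
    "unet": FakeModule("cpu", "float16"),
    "vae": FakeModule("cuda:0", "float32"),
    "text_encoder": FakeModule("cuda:0", "float64"),
    "tokenizer": object(),  # not a module, ignored
    "safety_checker": None,  # missing component, ignored
}
# Same components inserted in the opposite order.
reordered = dict(reversed(list(components.items())))
```

In this toy setup the sorted walk reaches `text_encoder` first regardless of insertion order, so both dictionaries report the same device and dtype.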
+2
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_accelerate,
require_bitsandbytes_version_greater,
require_peft_backend,
require_torch,
require_torch_gpu,
require_transformers_version_greater,
@@ -668,6 +669,7 @@ class SlowBnb4BitFluxTests(Base4bitTests):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
@require_peft_backend
def test_lora_loading(self):
self.pipeline_4bit.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+3
@@ -10,6 +10,7 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_accelerate,
require_big_gpu_with_torch_cuda,
require_torch_cuda_compatibility,
torch_device,
)
@@ -311,6 +312,7 @@ class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa
return {"weights_dtype": "int8"}
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.55
@@ -318,6 +320,7 @@ class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa
return {"weights_dtype": "int4"}
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.65
+17 -14
@@ -21,7 +21,15 @@ import PIL.Image
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.remote_utils import remote_decode
from diffusers.utils.constants import (
DECODE_ENDPOINT_FLUX,
DECODE_ENDPOINT_HUNYUAN_VIDEO,
DECODE_ENDPOINT_SD_V1,
DECODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
remote_decode,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
slow,
@@ -33,11 +41,6 @@ from diffusers.video_processor import VideoProcessor
enable_full_determinism()
ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
class RemoteAutoencoderKLMixin:
shape: Tuple[int, ...] = None
@@ -350,7 +353,7 @@ class RemoteAutoencoderKLSDv1Tests(
512,
512,
)
endpoint = ENDPOINT_SD_V1
endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
@@ -374,7 +377,7 @@ class RemoteAutoencoderKLSDXLTests(
1024,
1024,
)
endpoint = ENDPOINT_SD_XL
endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
@@ -398,7 +401,7 @@ class RemoteAutoencoderKLFluxTests(
1024,
1024,
)
endpoint = ENDPOINT_FLUX
endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
@@ -425,7 +428,7 @@ class RemoteAutoencoderKLFluxPackedTests(
)
height = 1024
width = 1024
endpoint = ENDPOINT_FLUX
endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
@@ -453,7 +456,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
320,
512,
)
endpoint = ENDPOINT_HUNYUAN_VIDEO
endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO
dtype = torch.float16
scaling_factor = 0.476986
processor_cls = VideoProcessor
@@ -504,7 +507,7 @@ class RemoteAutoencoderKLSDv1SlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
endpoint = ENDPOINT_SD_V1
endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
@@ -515,7 +518,7 @@ class RemoteAutoencoderKLSDXLSlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
endpoint = ENDPOINT_SD_XL
endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
@@ -527,7 +530,7 @@ class RemoteAutoencoderKLFluxSlowTests(
unittest.TestCase,
):
channels = 16
endpoint = ENDPOINT_FLUX
endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
+224
@@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import PIL.Image
import torch
from diffusers.utils import load_image
from diffusers.utils.constants import (
DECODE_ENDPOINT_FLUX,
DECODE_ENDPOINT_SD_V1,
DECODE_ENDPOINT_SD_XL,
ENCODE_ENDPOINT_FLUX,
ENCODE_ENDPOINT_SD_V1,
ENCODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
remote_decode,
remote_encode,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
slow,
)
enable_full_determinism()
IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
class RemoteAutoencoderKLEncodeMixin:
channels: int = None
endpoint: str = None
decode_endpoint: str = None
dtype: torch.dtype = None
scaling_factor: float = None
shift_factor: float = None
image: PIL.Image.Image = None
def get_dummy_inputs(self):
if self.image is None:
self.image = load_image(IMAGE)
inputs = {
"endpoint": self.endpoint,
"image": self.image,
"scaling_factor": self.scaling_factor,
"shift_factor": self.shift_factor,
}
return inputs
def test_image_input(self):
inputs = self.get_dummy_inputs()
height, width = inputs["image"].height, inputs["image"].width
output = remote_encode(**inputs)
self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
decoded = remote_decode(
tensor=output,
endpoint=self.decode_endpoint,
scaling_factor=self.scaling_factor,
shift_factor=self.shift_factor,
image_format="png",
)
self.assertEqual(decoded.height, height)
self.assertEqual(decoded.width, width)
# image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten())
# decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten())
# TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?
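`test_image_input` asserts that the encoded latent has shape `[1, channels, height // 8, width // 8]` and that decoding restores the input size. A small sketch of that shape arithmetic follows, assuming the 8x spatial downscale used in the assertion; `expected_latent_shape` and `decoded_size` are hypothetical helpers, not part of the test suite.

```python
from typing import List, Tuple


def expected_latent_shape(
    height: int, width: int, channels: int, downscale: int = 8, batch: int = 1
) -> List[int]:
    """Latent shape for an image of the given size, using floor division."""
    return [batch, channels, height // downscale, width // downscale]


def decoded_size(latent_shape: List[int], upscale: int = 8) -> Tuple[int, int]:
    """(height, width) obtained by upscaling the latent's spatial dimensions."""
    _, _, latent_height, latent_width = latent_shape
    return latent_height * upscale, latent_width * upscale


sd_latent = expected_latent_shape(512, 512, channels=4)
flux_latent = expected_latent_shape(1024, 1024, channels=16)
```

The round trip only restores the exact input size when both dimensions are multiples of the downscale factor, because the floor division discards any remainder.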
class RemoteAutoencoderKLSDv1Tests(
RemoteAutoencoderKLEncodeMixin,
unittest.TestCase,
):
channels = 4
endpoint = ENCODE_ENDPOINT_SD_V1
decode_endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
class RemoteAutoencoderKLSDXLTests(
RemoteAutoencoderKLEncodeMixin,
unittest.TestCase,
):
channels = 4
endpoint = ENCODE_ENDPOINT_SD_XL
decode_endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
class RemoteAutoencoderKLFluxTests(
RemoteAutoencoderKLEncodeMixin,
unittest.TestCase,
):
channels = 16
endpoint = ENCODE_ENDPOINT_FLUX
decode_endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
class RemoteAutoencoderKLEncodeSlowTestMixin:
channels: int = 4
endpoint: str = None
decode_endpoint: str = None
dtype: torch.dtype = None
scaling_factor: float = None
shift_factor: float = None
image: PIL.Image.Image = None
def get_dummy_inputs(self):
if self.image is None:
self.image = load_image(IMAGE)
inputs = {
"endpoint": self.endpoint,
"image": self.image,
"scaling_factor": self.scaling_factor,
"shift_factor": self.shift_factor,
}
return inputs
def test_multi_res(self):
inputs = self.get_dummy_inputs()
for height in {
320,
512,
640,
704,
896,
1024,
1208,
1384,
1536,
1608,
1864,
2048,
}:
for width in {
320,
512,
640,
704,
896,
1024,
1208,
1384,
1536,
1608,
1864,
2048,
}:
inputs["image"] = inputs["image"].resize(
(
width,
height,
)
)
output = remote_encode(**inputs)
self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
decoded = remote_decode(
tensor=output,
endpoint=self.decode_endpoint,
scaling_factor=self.scaling_factor,
shift_factor=self.shift_factor,
image_format="png",
)
self.assertEqual(decoded.height, height)
self.assertEqual(decoded.width, width)
decoded.save(f"test_multi_res_{height}_{width}.png")
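`test_multi_res` walks every height and width combination by iterating two set literals, whose iteration order is an implementation detail. A sketch of the same grid with a fixed order is shown below, using a tuple and `itertools.product`; the names `RESOLUTIONS` and `resolution_grid` are hypothetical.

```python
from itertools import product
from typing import Iterator, Tuple

# Same sizes as in the test above, kept in ascending order.
RESOLUTIONS = (320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048)


def resolution_grid(sizes: Tuple[int, ...] = RESOLUTIONS) -> Iterator[Tuple[int, int]]:
    """Yield (height, width) pairs in a fixed, reproducible order."""
    yield from product(sizes, repeat=2)


grid = list(resolution_grid())
```

Every size in the list is a multiple of 8, which is why the test can expect the decoded height and width to equal the requested ones. Resizing from the original image on each iteration, rather than from the previously resized copy, would also avoid compounding resampling loss across the grid.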
@slow
class RemoteAutoencoderKLSDv1SlowTests(
RemoteAutoencoderKLEncodeSlowTestMixin,
unittest.TestCase,
):
endpoint = ENCODE_ENDPOINT_SD_V1
decode_endpoint = DECODE_ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
@slow
class RemoteAutoencoderKLSDXLSlowTests(
RemoteAutoencoderKLEncodeSlowTestMixin,
unittest.TestCase,
):
endpoint = ENCODE_ENDPOINT_SD_XL
decode_endpoint = DECODE_ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
@slow
class RemoteAutoencoderKLFluxSlowTests(
RemoteAutoencoderKLEncodeSlowTestMixin,
unittest.TestCase,
):
channels = 16
endpoint = ENCODE_ENDPOINT_FLUX
decode_endpoint = DECODE_ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159