Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d9915a7d65 | |||
| b7a795dbeb | |||
| 438905d63e | |||
| 904f24de5a | |||
| e123bbcbc4 | |||
| b3fa8c695d | |||
| 720be2bac5 | |||
| e74b782aac | |||
| d6392b4b49 | |||
| 1475026960 | |||
| 878eb4ce35 |
@@ -38,7 +38,6 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install pandas peft
|
||||
python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
@@ -414,16 +414,12 @@ jobs:
|
||||
config:
|
||||
- backend: "bitsandbytes"
|
||||
test_location: "bnb"
|
||||
additional_deps: ["peft"]
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
additional_deps: []
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
additional_deps: []
|
||||
- backend: "optimum_quanto"
|
||||
test_location: "quanto"
|
||||
additional_deps: []
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
@@ -441,9 +437,6 @@ jobs:
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install -U ${{ matrix.config.backend }}
|
||||
if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
|
||||
python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
|
||||
fi
|
||||
python -m uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -81,8 +81,6 @@
|
||||
title: Overview
|
||||
- local: hybrid_inference/vae_decode
|
||||
title: VAE Decode
|
||||
- local: hybrid_inference/vae_encode
|
||||
title: VAE Encode
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
title: Hybrid Inference
|
||||
|
||||
@@ -58,10 +58,10 @@ Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fa
|
||||
First, load the pipeline:
|
||||
|
||||
```python
|
||||
from diffusers import LuminaPipeline
|
||||
from diffusers import LuminaText2ImgPipeline
|
||||
import torch
|
||||
|
||||
pipeline = LuminaPipeline.from_pretrained(
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
```
|
||||
@@ -86,11 +86,11 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes.
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaText2ImgPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaText2ImgPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
@@ -109,7 +109,7 @@ transformer_8bit = Transformer2DModel.from_pretrained(
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = LuminaPipeline.from_pretrained(
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
@@ -122,9 +122,9 @@ image = pipeline(prompt).images[0]
|
||||
image.save("lumina.png")
|
||||
```
|
||||
|
||||
## LuminaPipeline
|
||||
## LuminaText2ImgPipeline
|
||||
|
||||
[[autodoc]] LuminaPipeline
|
||||
[[autodoc]] LuminaText2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -36,14 +36,14 @@ Single file loading for Lumina Image 2.0 is available for the `Lumina2Transforme
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline
|
||||
|
||||
ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
|
||||
transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
ckpt_path, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
pipe = Lumina2Pipeline.from_pretrained(
|
||||
pipe = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -60,7 +60,7 @@ image.save("lumina-single-file.png")
|
||||
GGUF Quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`
|
||||
|
||||
```python
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig
|
||||
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
|
||||
transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
@@ -69,7 +69,7 @@ transformer = Lumina2Transformer2DModel.from_single_file(
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
pipe = Lumina2Pipeline.from_pretrained(
|
||||
pipe = Lumina2Text2ImgPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -80,8 +80,8 @@ image = pipe(
|
||||
image.save("lumina-gguf.png")
|
||||
```
|
||||
|
||||
## Lumina2Pipeline
|
||||
## Lumina2Text2ImgPipeline
|
||||
|
||||
[[autodoc]] Lumina2Pipeline
|
||||
[[autodoc]] Lumina2Text2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -14,365 +14,22 @@
|
||||
|
||||
# Wan
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
<!-- TODO(aryan): update abstract once paper is out -->
|
||||
|
||||
## Generating Videos with Wan 2.1
|
||||
|
||||
We will first need to install some addtional dependencies.
|
||||
|
||||
```shell
|
||||
pip install -u ftfy imageio-ffmpeg imageio
|
||||
```
|
||||
|
||||
### Text to Video Generation
|
||||
|
||||
The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
|
||||
for the larger `Wan2.1-I2V-14B-720P-Diffusers` or `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` if you have at least 35GB VRAM available.
|
||||
|
||||
```python
|
||||
from diffusers import WanPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-720P-Diffusers or Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
|
||||
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
num_frames = 33
|
||||
|
||||
frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
|
||||
export_to_video(frames, "wan-t2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
You can improve the quality of the generated video by running the decoding step in full precision.
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
from diffusers import WanPipeline, AutoencoderKLWan
|
||||
from diffusers.utils import export_to_video
|
||||
Recommendations for inference:
|
||||
- VAE in `torch.float32` for better decoding quality.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
|
||||
|
||||
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
|
||||
# replace this with pipe.to("cuda") if you have sufficient VRAM
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
num_frames = 33
|
||||
|
||||
frames = pipe(prompt=prompt, num_frames=num_frames).frames[0]
|
||||
export_to_video(frames, "wan-t2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Image to Video Generation
|
||||
|
||||
The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
|
||||
35GB of VRAM to run.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# replace this with pipe.to("cuda") if you have sufficient VRAM
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 480 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
## Memory Optimizations for Wan 2.1
|
||||
|
||||
Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
|
||||
|
||||
We'll use `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
|
||||
|
||||
### Group Offloading the Transformer and UMT5 Text Encoder
|
||||
|
||||
Find more information about group offloading [here](../optimization/memory.md)
|
||||
|
||||
#### Block Level Group Offloading
|
||||
|
||||
We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline; the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading will break up the individual modules of a model and offload/onload them onto your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which will group the modules in a model into blocks of size `num_blocks_per_group` and offload/onload them to GPU. Moving to between CPU and GPU does add latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing the `num_blocks_per_group`.
|
||||
|
||||
The following example will now only require 14GB of VRAM to run, but will take approximately 30 minutes to generate a video.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
apply_group_offloading(text_encoder,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4
|
||||
)
|
||||
|
||||
transformer.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4,
|
||||
)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
#### Block Level Group Offloading with CUDA Streams
|
||||
|
||||
We can speed up group offloading inference, by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by Pytorch under the hood, and can result in a significant spike in CPU RAM usage. Please consider this option if your CPU RAM is atleast 2X the size of the model you are group offloading.
|
||||
|
||||
In the following example we will use CUDA streams when group offloading the `WanTransformer3DModel`. When testing on an A100, this example will require 14GB of VRAM, 52GB of CPU RAM, but will generate a video in approximately 9 minutes.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionModel
|
||||
|
||||
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
apply_group_offloading(text_encoder,
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=4
|
||||
)
|
||||
|
||||
transformer.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True
|
||||
)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Applying Layerwise Casting to the Transformer
|
||||
|
||||
Find more information about layerwise casting [here](../optimization/memory.md)
|
||||
|
||||
In this example, we will model offloading with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast to `torch.bfloat16` during the forward pass of the layer, then revert to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
|
||||
|
||||
This example will require 20GB of VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import UMT5EncoderModel, CLIPVisionMode
|
||||
|
||||
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
||||
)
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
|
||||
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
|
||||
|
||||
max_area = 720 * 832
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
)
|
||||
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
num_frames = 33
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
export_to_video(output, "wan-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
### Using a Custom Scheduler
|
||||
### Using a custom scheduler
|
||||
|
||||
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
|
||||
|
||||
@@ -388,10 +45,11 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
|
||||
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
|
||||
```
|
||||
|
||||
## Using Single File Loading with Wan 2.1
|
||||
### Using single file loading with Wan
|
||||
|
||||
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
|
||||
method.
|
||||
|
||||
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
|
||||
method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -403,11 +61,6 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
|
||||
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
|
||||
```
|
||||
|
||||
## Recommendations for Inference:
|
||||
- Keep `AutencoderKLWan` in `torch.float32` for better decoding quality.
|
||||
- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
|
||||
|
||||
## WanPipeline
|
||||
|
||||
[[autodoc]] WanPipeline
|
||||
|
||||
@@ -3,7 +3,3 @@
|
||||
## Remote Decode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_decode
|
||||
|
||||
## Remote Encode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_encode
|
||||
|
||||
@@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
|
||||
## Available Models
|
||||
|
||||
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
|
||||
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
|
||||
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
|
||||
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
|
||||
|
||||
---
|
||||
@@ -46,15 +46,9 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
|
||||
## Changelog
|
||||
|
||||
- March 10 2025: Added VAE encode
|
||||
- March 2 2025: Initial release with VAE decoding
|
||||
|
||||
## Contents
|
||||
|
||||
The documentation is organized into three sections:
|
||||
The documentation is organized into two sections:
|
||||
|
||||
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
|
||||
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
|
||||
* **API Reference** Dive into task-specific settings and parameters.
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
# Getting Started: VAE Encode with Hybrid Inference
|
||||
|
||||
VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_encode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Let's encode an image, then decode it to demonstrate.
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
|
||||
</figure>
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
|
||||
|
||||
latent = remote_encode(
|
||||
endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
decoded = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode, remote_encode
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
init_image = load_image(
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
init_latent = remote_encode(
|
||||
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
image=init_image,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
image=init_latent,
|
||||
strength=0.75,
|
||||
output_type="latent",
|
||||
).images
|
||||
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("fantasy_landscape.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
@@ -79,13 +79,13 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
|
||||
@@ -378,7 +378,7 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="the concept to use to initialize the new inserted tokens when training with "
|
||||
"--train_text_encoder_ti = True. By default, new tokens (<si><si+1>) are initialized with random value. "
|
||||
"Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. "
|
||||
"Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. "
|
||||
"--num_new_tokens_per_abstraction is ignored when initializer_concept is provided",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. "
|
||||
"The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. "
|
||||
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
|
||||
),
|
||||
)
|
||||
|
||||
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -773,7 +773,7 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
@@ -1875,7 +1875,7 @@ def main(args):
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
|
||||
add_special_tokens = True if args.train_text_encoder_ti else False
|
||||
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
# Training CogView4 Control
|
||||
|
||||
This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources:
|
||||
|
||||
To incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`.
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
|
||||
|
||||
```bash
|
||||
accelerate launch train_control_lora_cogview4.py \
|
||||
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
|
||||
--dataset_name="raulc0399/open_pose_controlnet" \
|
||||
--output_dir="pose-control-lora" \
|
||||
--mixed_precision="bf16" \
|
||||
--train_batch_size=1 \
|
||||
--rank=64 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=5000 \
|
||||
--validation_image="openpose.png" \
|
||||
--validation_prompt="A couple, 4k photo, highly detailed" \
|
||||
--offload \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
|
||||
|
||||
You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`.
|
||||
|
||||
The training script exposes additional CLI args that might be useful to experiment with:
|
||||
|
||||
* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.
|
||||
* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading.
|
||||
* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached.
|
||||
|
||||
### Training with DeepSpeed
|
||||
|
||||
It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
And then while launching training, pass the config file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=CONFIG_FILE.yaml ...
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
|
||||
|
||||
```bash
|
||||
pip install controlnet_aux
|
||||
```
|
||||
|
||||
And then we are ready:
|
||||
|
||||
```py
|
||||
from controlnet_aux import OpenposeDetector
|
||||
from diffusers import CogView4ControlPipeline
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe.load_lora_weights("...") # change this.
|
||||
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
# prepare pose condition.
|
||||
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
|
||||
image = load_image(url)
|
||||
image = open_pose(image, detect_resolution=512, image_resolution=1024)
|
||||
image = np.array(image)[:, :, ::-1]
|
||||
image = Image.fromarray(np.uint8(image))
|
||||
|
||||
prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
control_image=image,
|
||||
num_inference_steps=50,
|
||||
joint_attention_kwargs={"scale": 0.9},
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
gen_images.save("output.png")
|
||||
```
|
||||
|
||||
## Full fine-tuning
|
||||
|
||||
We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \
|
||||
--pretrained_model_name_or_path="THUDM/CogView4-6B" \
|
||||
--dataset_name="raulc0399/open_pose_controlnet" \
|
||||
--output_dir="pose-control" \
|
||||
--mixed_precision="bf16" \
|
||||
--train_batch_size=2 \
|
||||
--dataloader_num_workers=4 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--proportion_empty_prompts=0.2 \
|
||||
--learning_rate=5e-5 \
|
||||
--adam_weight_decay=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="cosine" \
|
||||
--lr_warmup_steps=1000 \
|
||||
--checkpointing_steps=1000 \
|
||||
--max_train_steps=10000 \
|
||||
--validation_steps=200 \
|
||||
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
|
||||
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
|
||||
--offload \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Change the `validation_image` and `validation_prompt` as needed.
|
||||
|
||||
For inference, this time, we will run:
|
||||
|
||||
```py
|
||||
from controlnet_aux import OpenposeDetector
|
||||
from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
transformer = CogView4Transformer2DModel.from_pretrained("...") # change this.
|
||||
pipe = CogView4ControlPipeline.from_pretrained(
|
||||
"THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
# prepare pose condition.
|
||||
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
|
||||
image = load_image(url)
|
||||
image = open_pose(image, detect_resolution=512, image_resolution=1024)
|
||||
image = np.array(image)[:, :, ::-1]
|
||||
image = Image.fromarray(np.uint8(image))
|
||||
|
||||
prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
control_image=image,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
gen_images.save("output.png")
|
||||
```
|
||||
|
||||
## Things to note
|
||||
|
||||
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
|
||||
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
|
||||
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
|
||||
@@ -1,6 +0,0 @@
|
||||
transformers==4.47.0
|
||||
wandb
|
||||
torch
|
||||
torchvision
|
||||
accelerate==1.2.0
|
||||
peft>=0.14.0
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -1070,32 +1070,32 @@ class StableDiffusionXLTilingPipeline(
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left[row][col],
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left[row][col],
|
||||
negative_target_size,
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left[row][col],
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left[row][col],
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
|
||||
embeddings_and_added_time.append(addition_embed_type_row)
|
||||
|
||||
|
||||
@@ -152,7 +152,9 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -166,7 +166,9 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -1283,8 +1283,8 @@ def main(args):
|
||||
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
|
||||
prompt_embeds = batch["prompt_embeds"]
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"]
|
||||
|
||||
# controlnet(s) inference
|
||||
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
|
||||
|
||||
@@ -157,7 +157,9 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -381,7 +381,9 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -164,7 +164,9 @@ def log_validation(
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -50,116 +50,51 @@ python flux_inference.py
|
||||
|
||||
The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest.
|
||||
|
||||
On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
|
||||
On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel):
|
||||
|
||||
```bash
|
||||
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
|
||||
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 7.06it/s]
|
||||
Loading pipeline components...: 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.28it/s]
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.66it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 4.48it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.32it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.69it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.74it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.10it/s]
|
||||
2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.55it/s]
|
||||
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.46it/s]
|
||||
2025-03-14 21:18:34 [info ] starting compilation run...
|
||||
2025-03-14 21:18:37 [info ] starting compilation run...
|
||||
2025-03-14 21:18:38 [info ] starting compilation run...
|
||||
2025-03-14 21:18:39 [info ] starting compilation run...
|
||||
2025-03-14 21:18:41 [info ] starting compilation run...
|
||||
2025-03-14 21:18:41 [info ] starting compilation run...
|
||||
2025-03-14 21:18:42 [info ] starting compilation run...
|
||||
2025-03-14 21:18:43 [info ] starting compilation run...
|
||||
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]
|
||||
2025-03-14 21:36:38 [info ] compilation took 1079.3314765350078 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
|
||||
2025-03-14 21:36:38 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
|
||||
2025-03-14 21:36:38 [info ] compilation took 1081.89390801001 sec.
|
||||
2025-03-14 21:36:39 [info ] starting inference run...
|
||||
2025-03-14 21:36:39 [info ] compilation took 1077.1543154849933 sec.
|
||||
2025-03-14 21:36:39 [info ] compilation took 1075.7239800530078 sec.
|
||||
2025-03-14 21:36:39 [info ] starting inference run...
|
||||
2025-03-14 21:36:40 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]
|
||||
2025-03-14 21:36:50 [info ] compilation took 1088.1632604240003 sec.
|
||||
2025-03-14 21:36:50 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]
|
||||
2025-03-14 21:36:55 [info ] compilation took 1096.8027802760043 sec.
|
||||
2025-03-14 21:36:56 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]
|
||||
2025-03-14 21:37:08 [info ] compilation took 1113.8591305939917 sec.
|
||||
2025-03-14 21:37:08 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]
|
||||
2025-03-14 21:37:22 [info ] compilation took 1120.5590810020076 sec.
|
||||
2025-03-14 21:37:22 [info ] starting inference run...
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.00it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.98it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.08it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.82it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00, 3.34it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.22it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.09it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.41it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.50it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.10it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.27it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 4.80it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.67it/s]
|
||||
29%|█████████████████████████████████████████████████████████████████████████████▍ | 8/28 [00:01<00:03, 6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
|
||||
images = (images * 255).round().astype("uint8")
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.82it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.93it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.98it/s]
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.03it/s]2025-03-14 21:38:32 [info ] inference time: 5.962021178987925
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
|
||||
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.2685392687970305 sec.
|
||||
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.402720856998348 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.01it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.06it/s]
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.01it/s]2025-03-14 21:38:38 [info ] inference time: 5.950578948002658
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.87it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.00it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.86it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.99it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.05it/s]
|
||||
2025-03-14 21:38:43 [info ] avg. inference over 5 iterations took 6.763298449796276 sec.
|
||||
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.04it/s]2025-03-14 21:38:44 [info ] inference time: 5.949129879008979
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.92it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.10it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████ | 11/28 [00:01<00:02, 5.98it/s]2025-03-14 21:38:46 [info ] avg. inference over 5 iterations took 7.221068455604836 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.08it/s]
|
||||
93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 26/28 [00:04<00:00, 5.92it/s]2025-03-14 21:38:50 [info ] inference time: 5.954778069004533
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.90it/s]
|
||||
11%|█████████████████████████████ | 3/28 [00:00<00:04, 6.03it/s]2025-03-14 21:38:50 [info ] avg. inference over 5 iterations took 6.05970350120042 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
|
||||
32%|███████████████████████████████████████████████████████████████████████████████████████ | 9/28 [00:01<00:03, 5.99it/s]2025-03-14 21:38:51 [info ] avg. inference over 5 iterations took 6.018543455796316 sec.
|
||||
54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 15/28 [00:02<00:02, 6.00it/s]2025-03-14 21:38:52 [info ] avg. inference over 5 iterations took 5.9609976705978625 sec.
|
||||
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.97it/s]
|
||||
2025-03-14 21:38:56 [info ] inference time: 5.944058528999449
|
||||
2025-03-14 21:38:56 [info ] avg. inference over 5 iterations took 5.952113320800708 sec.
|
||||
2025-03-14 21:38:56 [info ] saved metric information as /tmp/metrics_report.txt
|
||||
Loading checkpoint shards: 100%|███████████████████████████████| 2/2 [00:00<00:00, 7.01it/s]
|
||||
Loading pipeline components...: 40%|██████████▍ | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
|
||||
Loading pipeline components...: 100%|██████████████████████████| 5/5 [00:00<00:00, 6.72it/s]
|
||||
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 4.29it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.26it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.27it/s]
|
||||
Loading pipeline components...: 100%|██████████████████████████| 3/3 [00:00<00:00, 3.25it/s]
|
||||
2025-01-10 00:51:34 [info ] starting compilation run...
|
||||
2025-01-10 00:51:35 [info ] starting compilation run...
|
||||
2025-01-10 00:51:37 [info ] starting compilation run...
|
||||
2025-01-10 00:51:37 [info ] starting compilation run...
|
||||
2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec.
|
||||
2025-01-10 00:52:53 [info ] starting inference run...
|
||||
2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec.
|
||||
2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec.
|
||||
2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec.
|
||||
2025-01-10 00:52:57 [info ] starting inference run...
|
||||
2025-01-10 00:52:57 [info ] starting inference run...
|
||||
2025-01-10 00:52:58 [info ] starting inference run...
|
||||
2025-01-10 00:53:22 [info ] inference time: 25.112665320000815
|
||||
2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655
|
||||
2025-01-10 00:53:38 [info ] inference time: 7.693858365000779
|
||||
2025-01-10 00:53:46 [info ] inference time: 7.690621814001133
|
||||
2025-01-10 00:53:53 [info ] inference time: 7.679490454000188
|
||||
2025-01-10 00:54:01 [info ] inference time: 7.68949568500102
|
||||
2025-01-10 00:54:09 [info ] inference time: 7.686633744000574
|
||||
2025-01-10 00:54:16 [info ] inference time: 7.696786873999372
|
||||
2025-01-10 00:54:24 [info ] inference time: 7.691988694999964
|
||||
2025-01-10 00:54:32 [info ] inference time: 7.700649563999832
|
||||
2025-01-10 00:54:39 [info ] inference time: 7.684993574001055
|
||||
2025-01-10 00:54:47 [info ] inference time: 7.68343457499941
|
||||
2025-01-10 00:54:55 [info ] inference time: 7.667921153999487
|
||||
2025-01-10 00:55:02 [info ] inference time: 7.683585194001353
|
||||
2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec.
|
||||
2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec.
|
||||
2025-01-10 00:55:10 [info ] inference time: 7.673799695001435
|
||||
2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec.
|
||||
2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt
|
||||
2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec.
|
||||
```
|
||||
@@ -9,7 +9,6 @@ import torch_xla.debug.metrics as met
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
import torch_xla.runtime as xr
|
||||
from torch_xla.experimental.custom_kernel import FlashAttention
|
||||
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
@@ -37,19 +36,6 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
|
||||
).to(device0)
|
||||
flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
|
||||
FlashAttention.DEFAULT_BLOCK_SIZES = {
|
||||
"block_q": 1536,
|
||||
"block_k_major": 1536,
|
||||
"block_k": 1536,
|
||||
"block_b": 1536,
|
||||
"block_q_major_dkv": 1536,
|
||||
"block_k_major_dkv": 1536,
|
||||
"block_q_dkv": 1536,
|
||||
"block_k_dkv": 1536,
|
||||
"block_q_dq": 1536,
|
||||
"block_k_dq": 1536,
|
||||
"block_k_major_dq": 1536,
|
||||
}
|
||||
|
||||
prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
|
||||
width = args.width
|
||||
@@ -83,14 +69,14 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
xm.set_rng_state(seed=unique_seed, device=device0)
|
||||
times = []
|
||||
logger.info("starting inference run...")
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
|
||||
prompt=prompt, prompt_2=None, max_sequence_length=512
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(device0)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
|
||||
for _ in range(args.itters):
|
||||
ts = perf_counter()
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
|
||||
prompt=prompt, prompt_2=None, max_sequence_length=512
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(device0)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
|
||||
|
||||
if args.profile:
|
||||
xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
|
||||
@@ -106,7 +92,7 @@ def _main(index, args, text_pipe, ckpt_id):
|
||||
if index == 0:
|
||||
logger.info(f"inference time: {inference_time}")
|
||||
times.append(inference_time)
|
||||
logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
|
||||
logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
|
||||
image.save(f"/tmp/inference_out-{index}.png")
|
||||
if index == 0:
|
||||
metrics_report = met.metrics_report()
|
||||
|
||||
@@ -141,7 +141,9 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
|
||||
formatted_images = [np.asarray(validation_image)]
|
||||
formatted_images = []
|
||||
|
||||
formatted_images.append(np.asarray(validation_image))
|
||||
|
||||
for image in images:
|
||||
formatted_images.append(np.asarray(image))
|
||||
|
||||
@@ -53,18 +53,8 @@ args = parser.parse_args()
|
||||
# this is specific to `AdaLayerNormContinuous`:
|
||||
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
|
||||
def swap_scale_shift(weight, dim):
|
||||
"""
|
||||
Swap the scale and shift components in the weight tensor.
|
||||
|
||||
Args:
|
||||
weight (torch.Tensor): The original weight tensor.
|
||||
dim (int): The dimension along which to split.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The modified weight tensor with scale and shift swapped.
|
||||
"""
|
||||
shift, scale = weight.chunk(2, dim=dim)
|
||||
new_weight = torch.cat([scale, shift], dim=dim)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
@@ -210,7 +200,6 @@ def main(args):
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 1.0,
|
||||
"shift_factor": 0.0,
|
||||
"force_upcast": True,
|
||||
"use_quant_conv": False,
|
||||
"use_post_quant_conv": False,
|
||||
|
||||
@@ -25,15 +25,9 @@ import argparse
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import GlmModel, PreTrainedTokenizerFast
|
||||
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
CogView4Transformer2DModel,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
|
||||
|
||||
@@ -118,12 +112,6 @@ parser.add_argument(
|
||||
default=128,
|
||||
help="Maximum size for positional embeddings.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to use control model.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -162,15 +150,13 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
Returns:
|
||||
dict: The converted state dictionary compatible with Diffusers.
|
||||
"""
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
mega = ckpt["model"]
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# Patch Embedding
|
||||
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
|
||||
hidden_size, 128 if args.control else 64
|
||||
)
|
||||
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
|
||||
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
|
||||
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
|
||||
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
|
||||
@@ -203,8 +189,14 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
|
||||
# AdaLayerNorm
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
|
||||
mega[f"decoder.layers.{i}.adaln.weight"], dim=0
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
|
||||
mega[f"decoder.layers.{i}.adaln.bias"], dim=0
|
||||
)
|
||||
|
||||
# QKV
|
||||
qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
|
||||
qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
|
||||
|
||||
@@ -229,7 +221,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
|
||||
# Attention Output
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
|
||||
f"decoder.layers.{i}.self_attention.linear_proj.weight"
|
||||
]
|
||||
].T
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
|
||||
f"decoder.layers.{i}.self_attention.linear_proj.bias"
|
||||
]
|
||||
@@ -260,7 +252,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
|
||||
Returns:
|
||||
dict: The converted VAE state dictionary compatible with Diffusers.
|
||||
"""
|
||||
original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
|
||||
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
|
||||
|
||||
|
||||
@@ -294,7 +286,7 @@ def main(args):
|
||||
)
|
||||
transformer = CogView4Transformer2DModel(
|
||||
patch_size=2,
|
||||
in_channels=32 if args.control else 16,
|
||||
in_channels=16,
|
||||
num_layers=args.num_layers,
|
||||
attention_head_dim=args.attention_head_dim,
|
||||
num_attention_heads=args.num_heads,
|
||||
@@ -325,7 +317,6 @@ def main(args):
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 1.0,
|
||||
"shift_factor": 0.0,
|
||||
"force_upcast": True,
|
||||
"use_quant_conv": False,
|
||||
"use_post_quant_conv": False,
|
||||
@@ -340,7 +331,7 @@ def main(args):
|
||||
# Load the text encoder and tokenizer
|
||||
text_encoder_id = "THUDM/glm-4-9b-hf"
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
|
||||
text_encoder = GlmModel.from_pretrained(
|
||||
text_encoder = GlmForCausalLM.from_pretrained(
|
||||
text_encoder_id,
|
||||
cache_dir=args.text_encoder_cache_dir,
|
||||
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
|
||||
@@ -354,22 +345,13 @@ def main(args):
|
||||
)
|
||||
|
||||
# Create the pipeline
|
||||
if args.control:
|
||||
pipe = CogView4ControlPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
pipe = CogView4Pipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe = CogView4Pipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# Save the converted pipeline
|
||||
pipe.save_pretrained(
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -115,7 +115,7 @@ def main(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
|
||||
|
||||
pipeline = LuminaPipeline(
|
||||
pipeline = LuminaText2ImgPipeline(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
|
||||
)
|
||||
pipeline.save_pretrained(args.dump_path)
|
||||
|
||||
@@ -345,7 +345,6 @@ else:
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXVideoToVideoPipeline",
|
||||
"CogView3PlusPipeline",
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
@@ -404,9 +403,7 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
"LuminaPipeline",
|
||||
"LuminaText2ImgPipeline",
|
||||
"MarigoldDepthPipeline",
|
||||
"MarigoldIntrinsicsPipeline",
|
||||
@@ -890,7 +887,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogVideoXPipeline,
|
||||
CogVideoXVideoToVideoPipeline,
|
||||
CogView3PlusPipeline,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
@@ -949,9 +945,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
LuminaPipeline,
|
||||
LuminaText2ImgPipeline,
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldIntrinsicsPipeline,
|
||||
|
||||
@@ -29,11 +29,16 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
|
||||
|
||||
# Always use memory-efficient CPU offloading to minimize RAM usage
|
||||
|
||||
_SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
|
||||
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
|
||||
@@ -56,7 +61,6 @@ class ModuleGroup:
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
|
||||
onload_self: bool = True,
|
||||
) -> None:
|
||||
self.modules = modules
|
||||
@@ -68,12 +72,8 @@ class ModuleGroup:
|
||||
self.buffers = buffers
|
||||
self.non_blocking = non_blocking or stream is not None
|
||||
self.stream = stream
|
||||
self.cpu_param_dict = cpu_param_dict
|
||||
self.onload_self = onload_self
|
||||
|
||||
if self.stream is not None and self.cpu_param_dict is None:
|
||||
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
|
||||
@@ -82,23 +82,125 @@ class ModuleGroup:
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
# Use the most efficient module-level transfer when possible
|
||||
# This approach mirrors how PyTorch handles full model transfers
|
||||
if self.modules:
|
||||
for group_module in self.modules:
|
||||
# Only onload if some parameters are not on the target device
|
||||
if any(p.device != self.onload_device for p in group_module.parameters()):
|
||||
try:
|
||||
# Try the most efficient approach using _apply
|
||||
if hasattr(group_module, "_apply"):
|
||||
# This is what module.to() uses internally
|
||||
def to_device(t):
|
||||
if t.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
return t.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
return t.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
return t
|
||||
|
||||
# Apply to all tensors without unnecessary copies
|
||||
group_module._apply(to_device)
|
||||
else:
|
||||
# Fallback to direct parameter transfer
|
||||
for param in group_module.parameters():
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
except Exception as e:
|
||||
# If optimization fails, fall back to direct parameter transfer
|
||||
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
|
||||
for param in group_module.parameters():
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
# Handle explicit parameters
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if param.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
param.data = param.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
param.data = param.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
# Handle buffers
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if buffer.device != self.onload_device:
|
||||
if self.onload_device.type == "cuda":
|
||||
buffer.data = buffer.data.cuda(self.onload_device.index,
|
||||
non_blocking=self.non_blocking)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.onload_device,
|
||||
non_blocking=self.non_blocking)
|
||||
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
# For CPU offloading
|
||||
if self.offload_device.type == "cpu":
|
||||
# Synchronize if using stream
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# Empty GPU cache before offloading to reduce memory fragmentation
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# For module groups, use a single, unified approach that is closest to
|
||||
# the behavior of model.to("cpu")
|
||||
if self.modules:
|
||||
for group_module in self.modules:
|
||||
# Check if we need to offload this module
|
||||
if any(p.device.type != "cpu" for p in group_module.parameters()):
|
||||
# Use PyTorch's built-in to() method directly, which preserves
|
||||
# memory mapping when moving to CPU
|
||||
try:
|
||||
# Non-blocking=False for CPU transfers, as it ensures memory is
|
||||
# immediately available and potentially preserves memory mapping
|
||||
group_module.to("cpu", non_blocking=False)
|
||||
except Exception as e:
|
||||
# If there's any error, fall back to parameter-level offloading
|
||||
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
|
||||
for param in group_module.parameters():
|
||||
if param.device.type != "cpu":
|
||||
param.data = param.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Handle explicit parameters - move directly to CPU with non-blocking=False
|
||||
# which can preserve memory mapping in some PyTorch versions
|
||||
if self.parameters is not None:
|
||||
for param in self.parameters:
|
||||
if param.device.type != "cpu":
|
||||
param.data = param.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Handle buffers
|
||||
if self.buffers is not None:
|
||||
for buffer in self.buffers:
|
||||
if buffer.device.type != "cpu":
|
||||
buffer.data = buffer.data.to("cpu", non_blocking=False)
|
||||
|
||||
# Let Python's normal reference counting handle cleanup
|
||||
# We don't force garbage collection to avoid slowing down inference
|
||||
|
||||
else:
|
||||
# For non-CPU offloading, synchronize if using stream
|
||||
if self.stream is not None:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# For non-CPU offloading, use the regular approach
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
if self.parameters is not None:
|
||||
@@ -108,6 +210,9 @@ class ModuleGroup:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
# After offloading, we can unpin the memory if configured to do so
|
||||
# We'll keep it pinned by default for better performance
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
r"""
|
||||
@@ -129,6 +234,7 @@ class GroupOffloadingHook(ModelHook):
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self.group.offload_leader == module:
|
||||
# Offload to CPU
|
||||
self.group.offload_()
|
||||
return module
|
||||
|
||||
@@ -313,7 +419,8 @@ def apply_group_offloading(
|
||||
If True, offloading and onloading is done with non-blocking data transfer.
|
||||
use_stream (`bool`, defaults to `False`):
|
||||
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
|
||||
overlapping computation and data transfer.
|
||||
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
|
||||
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -344,12 +451,19 @@ def apply_group_offloading(
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
# We no longer need a pinned group manager as we're not using pinned memory
|
||||
|
||||
if offload_type == "block_level":
|
||||
if num_blocks_per_group is None:
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
|
||||
module,
|
||||
num_blocks_per_group,
|
||||
offload_device,
|
||||
onload_device,
|
||||
non_blocking,
|
||||
stream,
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
|
||||
@@ -384,12 +498,7 @@ def _apply_group_offloading_block_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
# We no longer need a CPU parameter dictionary
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -411,7 +520,6 @@ def _apply_group_offloading_block_level(
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=stream is None,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -448,7 +556,6 @@ def _apply_group_offloading_block_level(
|
||||
buffers=buffers,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
|
||||
@@ -483,12 +590,7 @@ def _apply_group_offloading_leaf_level(
|
||||
for overlapping computation and data transfer.
|
||||
"""
|
||||
|
||||
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
|
||||
cpu_param_dict = None
|
||||
if stream is not None:
|
||||
for param in module.parameters():
|
||||
param.data = param.data.cpu().pin_memory()
|
||||
cpu_param_dict = {param: param.data for param in module.parameters()}
|
||||
# We no longer need a CPU parameter dictionary
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -503,7 +605,6 @@ def _apply_group_offloading_leaf_level(
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
@@ -548,7 +649,6 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
cpu_param_dict=cpu_param_dict,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
@@ -567,7 +667,6 @@ def _apply_group_offloading_leaf_level(
|
||||
buffers=None,
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
cpu_param_dict=None,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
|
||||
@@ -804,7 +804,9 @@ class SD3IPAdapterMixin:
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(
|
||||
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||
).to(self.device),
|
||||
|
||||
@@ -423,12 +423,8 @@ def _load_lora_into_text_encoder(
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
logger.info(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1348,56 +1348,3 @@ def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
|
||||
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
||||
|
||||
for i in range(num_blocks):
|
||||
# Self-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.self_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
# Cross-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
if is_i2v_lora:
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.cross_attn.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
# FFN
|
||||
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
|
||||
f"blocks.{i}.{o}.lora_B.weight"
|
||||
)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -42,7 +42,6 @@ from .lora_conversion_utils import (
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_non_diffusers_wan_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
)
|
||||
@@ -452,11 +451,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
|
||||
@@ -477,7 +472,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -896,11 +891,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
|
||||
@@ -921,7 +912,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -1299,11 +1290,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
@@ -1325,7 +1312,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -1841,11 +1828,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
@@ -1866,7 +1849,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
|
||||
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
|
||||
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
|
||||
def unload_lora_weights(self, reset_to_overwritten_params=False):
|
||||
@@ -2565,11 +2548,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
@@ -2587,7 +2566,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -2873,11 +2852,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -2896,7 +2871,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3182,11 +3157,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -3205,7 +3176,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3491,11 +3462,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -3514,7 +3481,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -3803,11 +3770,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -3826,7 +3789,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -4116,11 +4079,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
|
||||
@@ -4139,7 +4098,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -4152,6 +4111,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
@@ -4238,8 +4198,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
if any(k.startswith("diffusion_model.") for k in state_dict):
|
||||
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
@@ -4426,11 +4384,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -4449,7 +4403,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
@@ -4735,11 +4689,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
@@ -4758,7 +4708,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
|
||||
@@ -354,12 +354,8 @@ class PeftAdapterMixin:
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
logger.info(
|
||||
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {self.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
def save_lora_adapter(
|
||||
|
||||
@@ -741,14 +741,10 @@ class Attention(nn.Module):
|
||||
|
||||
if out_dim == 3:
|
||||
if attention_mask.shape[0] < batch_size * head_size:
|
||||
attention_mask = attention_mask.repeat_interleave(
|
||||
head_size, dim=0, output_size=attention_mask.shape[0] * head_size
|
||||
)
|
||||
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
||||
elif out_dim == 4:
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
attention_mask = attention_mask.repeat_interleave(
|
||||
head_size, dim=1, output_size=attention_mask.shape[1] * head_size
|
||||
)
|
||||
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
||||
|
||||
return attention_mask
|
||||
|
||||
@@ -2339,9 +2335,7 @@ class FluxAttnProcessor2_0:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
@@ -3710,10 +3704,8 @@ class StableAudioAttnProcessor2_0:
|
||||
if kv_heads != attn.heads:
|
||||
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
|
||||
heads_per_kv_head = attn.heads // kv_heads
|
||||
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
|
||||
value = torch.repeat_interleave(
|
||||
value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
|
||||
)
|
||||
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
|
||||
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
|
||||
@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
|
||||
x = F.pixel_shuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1)
|
||||
y = F.pixel_shuffle(y, self.factor)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
@@ -361,9 +361,7 @@ class Decoder(nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.in_shortcut:
|
||||
x = hidden_states.repeat_interleave(
|
||||
self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
|
||||
)
|
||||
x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
|
||||
hidden_states = self.conv_in(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
|
||||
if self.down_sample:
|
||||
identity = hidden_states[:, :, ::2]
|
||||
elif self.up_sample:
|
||||
identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
|
||||
identity = hidden_states.repeat_interleave(2, dim=2)
|
||||
else:
|
||||
identity = hidden_states
|
||||
|
||||
|
||||
@@ -426,9 +426,7 @@ class FourierFeatures(nn.Module):
|
||||
w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1]
|
||||
|
||||
# Interleaved repeat of input channels to match w
|
||||
h = inputs.repeat_interleave(
|
||||
num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
|
||||
) # [B, C * num_freqs, T, H, W]
|
||||
h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
|
||||
# Scale channels by frequency.
|
||||
h = w * h
|
||||
|
||||
|
||||
@@ -687,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
|
||||
emb = emb.repeat_interleave(sample_num_frames, dim=0)
|
||||
|
||||
# 2. pre-process
|
||||
batch_size, channels, num_frames, height, width = sample.shape
|
||||
|
||||
@@ -139,9 +139,7 @@ def get_3d_sincos_pos_embed(
|
||||
|
||||
# 3. Concat
|
||||
pos_embed_spatial = pos_embed_spatial[None, :, :]
|
||||
pos_embed_spatial = pos_embed_spatial.repeat_interleave(
|
||||
temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
|
||||
) # [T, H*W, D // 4 * 3]
|
||||
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
|
||||
|
||||
pos_embed_temporal = pos_embed_temporal[:, None, :]
|
||||
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
|
||||
@@ -1154,13 +1152,10 @@ def get_1d_rotary_pos_embed(
|
||||
/ linear_factor
|
||||
) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||
is_npu = freqs.device.type == "npu"
|
||||
if is_npu:
|
||||
freqs = freqs.float()
|
||||
if use_real and repeat_interleave_real:
|
||||
# flux, hunyuan-dit, cogvideox
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# stable audio, allegro
|
||||
|
||||
@@ -227,17 +227,13 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
# Prepare text embeddings for spatial block
|
||||
# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
|
||||
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
|
||||
num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
|
||||
).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
|
||||
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
|
||||
-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
|
||||
)
|
||||
|
||||
# Prepare timesteps for spatial and temporal block
|
||||
timestep_spatial = timestep.repeat_interleave(
|
||||
num_frame, dim=0, output_size=timestep.shape[0] * num_frame
|
||||
).view(-1, timestep.shape[-1])
|
||||
timestep_temp = timestep.repeat_interleave(
|
||||
num_patches, dim=0, output_size=timestep.shape[0] * num_patches
|
||||
).view(-1, timestep.shape[-1])
|
||||
timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
|
||||
timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
|
||||
|
||||
# Spatial and temporal transformer blocks
|
||||
for i, (spatial_block, temp_block) in enumerate(
|
||||
@@ -303,9 +299,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
|
||||
).permute(0, 2, 1, 3)
|
||||
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
|
||||
|
||||
embedded_timestep = embedded_timestep.repeat_interleave(
|
||||
num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
|
||||
).view(-1, embedded_timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
|
||||
@@ -353,11 +353,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(
|
||||
self.config.num_attention_heads,
|
||||
dim=0,
|
||||
output_size=attention_mask.shape[0] * self.config.num_attention_heads,
|
||||
)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
@@ -23,7 +23,6 @@ from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -127,8 +126,7 @@ class CogView4AttnProcessor:
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
@@ -158,15 +156,6 @@ class CogView4AttnProcessor:
|
||||
)
|
||||
|
||||
# 4. Attention
|
||||
if attention_mask is not None:
|
||||
text_attention_mask = attention_mask.float().to(query.device)
|
||||
actual_text_seq_length = text_attention_mask.size(1)
|
||||
new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
|
||||
new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
|
||||
new_attention_mask = new_attention_mask.unsqueeze(2)
|
||||
attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
|
||||
attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
@@ -214,8 +203,6 @@ class CogView4TransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
@@ -236,8 +223,6 @@ class CogView4TransformerBlock(nn.Module):
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
|
||||
@@ -304,7 +289,7 @@ class CogView4RotaryPosEmbed(nn.Module):
|
||||
return (freqs.cos(), freqs.sin())
|
||||
|
||||
|
||||
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
@@ -401,8 +386,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
crop_coords: torch.Tensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
@@ -438,11 +421,11 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
|
||||
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
|
||||
hidden_states, encoder_hidden_states, temb, image_rotary_emb
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
|
||||
@@ -441,14 +441,6 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
|
||||
# 5. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
# first device rather than the last device, which hidden_states ends up
|
||||
# on.
|
||||
shift = shift.to(hidden_states.device)
|
||||
scale = scale.to(hidden_states.device)
|
||||
|
||||
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
|
||||
@@ -638,10 +638,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
|
||||
num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
|
||||
)
|
||||
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
# 2. pre-process
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
||||
|
||||
@@ -592,7 +592,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
|
||||
# 3. time + FPS embeddings.
|
||||
emb = t_emb + fps_emb
|
||||
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
|
||||
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
# 4. context embeddings.
|
||||
# The context embeddings consist of both text embeddings from the input prompt
|
||||
@@ -620,7 +620,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
image_emb = self.context_embedding(image_embeddings)
|
||||
image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
|
||||
context_emb = torch.cat([context_emb, image_emb], dim=1)
|
||||
context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)
|
||||
context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
|
||||
image_latents.shape[0] * image_latents.shape[2],
|
||||
|
||||
@@ -2059,7 +2059,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb if aug_emb is None else emb + aug_emb
|
||||
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
|
||||
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
||||
|
||||
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
||||
if "image_embeds" not in added_cond_kwargs:
|
||||
@@ -2068,10 +2068,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
)
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
image_embeds = self.encoder_hid_proj(image_embeds)
|
||||
image_embeds = [
|
||||
image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
|
||||
for image_embed in image_embeds
|
||||
]
|
||||
image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
|
||||
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
||||
|
||||
# 2. pre-process
|
||||
|
||||
@@ -431,11 +431,9 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
sample = sample.flatten(0, 1)
|
||||
# Repeat the embeddings num_video_frames times
|
||||
# emb: [batch, channels] -> [batch * frames, channels]
|
||||
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
|
||||
emb = emb.repeat_interleave(num_frames, dim=0)
|
||||
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
|
||||
num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
|
||||
)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
@@ -154,7 +154,7 @@ else:
|
||||
"CogVideoXFunControlPipeline",
|
||||
]
|
||||
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline"]
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
@@ -265,8 +265,8 @@ else:
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
@@ -511,7 +511,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogVideoXVideoToVideoPipeline,
|
||||
)
|
||||
from .cogview3 import CogView3PlusPipeline
|
||||
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
|
||||
from .cogview4 import CogView4Pipeline
|
||||
from .consisid import ConsisIDPipeline
|
||||
from .controlnet import (
|
||||
BlipDiffusionControlNetPipeline,
|
||||
@@ -619,8 +619,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXImageToVideoPipeline, LTXPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .lumina import LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldIntrinsicsPipeline,
|
||||
|
||||
@@ -22,7 +22,7 @@ from ..models.controlnets import ControlNetUnionModel
|
||||
from ..utils import is_sentencepiece_available
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .cogview3 import CogView3PlusPipeline
|
||||
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
|
||||
from .cogview4 import CogView4Pipeline
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
@@ -69,8 +69,8 @@ from .kandinsky2_2 import (
|
||||
)
|
||||
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .lumina import LuminaPipeline
|
||||
from .lumina2 import Lumina2Pipeline
|
||||
from .lumina import LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Text2ImgPipeline
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
@@ -141,11 +141,10 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("flux", FluxPipeline),
|
||||
("flux-control", FluxControlPipeline),
|
||||
("flux-controlnet", FluxControlNetPipeline),
|
||||
("lumina", LuminaPipeline),
|
||||
("lumina2", Lumina2Pipeline),
|
||||
("lumina", LuminaText2ImgPipeline),
|
||||
("lumina2", Lumina2Text2ImgPipeline),
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
("cogview4", CogView4Pipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_cogview4"] = ["CogView4Pipeline"]
|
||||
_import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -32,7 +31,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_cogview4 import CogView4Pipeline
|
||||
from .pipeline_cogview4_control import CogView4ControlPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -389,18 +389,14 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -537,7 +533,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# Default call parameters
|
||||
@@ -615,7 +610,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
@@ -667,8 +661,6 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
|
||||
|
||||
@@ -1,727 +0,0 @@
|
||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import AutoencoderKL, CogView4Transformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import CogView4PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogView4ControlPipeline
|
||||
|
||||
>>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
|
||||
>>> control_image = load_image(
|
||||
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
... )
|
||||
>>> prompt = "A bird in space"
|
||||
>>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0]
|
||||
>>> image.save("cogview4-control.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
base_shift: float = 0.25,
|
||||
max_shift: float = 0.75,
|
||||
) -> float:
|
||||
m = (image_seq_len / base_seq_len) ** 0.5
|
||||
mu = m * max_shift + base_shift
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogView4ControlPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using CogView4.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`GLMModel`]):
|
||||
Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
|
||||
tokenizer (`PreTrainedTokenizer`):
|
||||
Tokenizer of class
|
||||
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
|
||||
transformer ([`CogView4Transformer2DModel`]):
|
||||
A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: GlmModel,
|
||||
vae: AutoencoderKL,
|
||||
transformer: CogView4Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds
|
||||
def _get_glm_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="longest", # not use max length
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
current_length = text_input_ids.shape[1]
|
||||
pad_length = (16 - (current_length % 16)) % 16
|
||||
if pad_length > 0:
|
||||
pad_ids = torch.full(
|
||||
(text_input_ids.shape[0], pad_length),
|
||||
fill_value=self.tokenizer.pad_token_id,
|
||||
dtype=text_input_ids.dtype,
|
||||
device=text_input_ids.device,
|
||||
)
|
||||
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(self.text_encoder.device), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
max_sequence_length (`int`, defaults to `1024`):
|
||||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = prompt_embeds.size(1)
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = negative_prompt_embeds.size(1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
if latents is not None:
|
||||
return latents.to(device)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
pass
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
control_image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 1024,
|
||||
) -> Union[CogView4PipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. If not provided, it is set to 1024.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. If not provided it is set to 1024.
|
||||
num_inference_steps (`int`, *optional*, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to `1`):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
||||
explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
||||
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
||||
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `224`):
|
||||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
|
||||
[`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = (height, width)
|
||||
|
||||
# Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# Default call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Prepare latents
|
||||
latent_channels = self.transformer.config.in_channels // 2
|
||||
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
vae_shift_factor = 0
|
||||
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# Prepare additional timestep conditions
|
||||
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
|
||||
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
|
||||
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
|
||||
|
||||
original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
|
||||
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
|
||||
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# Prepare timesteps
|
||||
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
|
||||
self.transformer.config.patch_size**2
|
||||
)
|
||||
|
||||
timesteps = (
|
||||
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
|
||||
if timesteps is None
|
||||
else np.array(timesteps)
|
||||
)
|
||||
timesteps = timesteps.astype(np.int64).astype(np.float32)
|
||||
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("base_shift", 0.25),
|
||||
self.scheduler.config.get("max_shift", 0.75),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
|
||||
)
|
||||
self._num_timesteps = len(timesteps)
|
||||
# Denoising loop
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
|
||||
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
|
||||
else:
|
||||
image = latents
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return CogView4PipelineOutput(images=image)
|
||||
@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -32,7 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .pipeline_lumina import LuminaText2ImgPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -30,7 +30,6 @@ from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
deprecate,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
is_torch_xla_available,
|
||||
@@ -61,9 +60,11 @@ EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import LuminaPipeline
|
||||
>>> from diffusers import LuminaText2ImgPipeline
|
||||
|
||||
>>> pipe = LuminaPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16)
|
||||
>>> pipe = LuminaText2ImgPipeline.from_pretrained(
|
||||
... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> # Enable memory optimizations.
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
@@ -133,7 +134,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class LuminaPipeline(DiffusionPipeline):
|
||||
class LuminaText2ImgPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Lumina-T2I.
|
||||
|
||||
@@ -931,23 +932,3 @@ class LuminaPipeline(DiffusionPipeline):
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
|
||||
|
||||
class LuminaText2ImgPipeline(LuminaPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
transformer: LuminaNextDiT2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: GemmaPreTrainedModel,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
):
|
||||
deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead."
|
||||
deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["pipeline_lumina2"] = ["Lumina2Text2ImgPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -32,7 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .pipeline_lumina2 import Lumina2Text2ImgPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -25,7 +25,6 @@ from ...models import AutoencoderKL
|
||||
from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -48,9 +47,9 @@ EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import Lumina2Pipeline
|
||||
>>> from diffusers import Lumina2Text2ImgPipeline
|
||||
|
||||
>>> pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
|
||||
>>> pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
|
||||
>>> # Enable memory optimizations.
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
@@ -134,7 +133,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
|
||||
class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Lumina-T2I.
|
||||
|
||||
@@ -768,23 +767,3 @@ class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
|
||||
|
||||
class Lumina2Text2ImgPipeline(Lumina2Pipeline):
|
||||
def __init__(
|
||||
self,
|
||||
transformer: Lumina2Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
):
|
||||
deprecation_message = "`Lumina2Text2ImgPipeline` has been renamed to `Lumina2Pipeline` and will be removed in a future version. Please use `Lumina2Pipeline` instead."
|
||||
deprecate("diffusers.pipelines.lumina2.pipeline_lumina2.Lumina2Text2ImgPipeline", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
@@ -1610,7 +1610,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
expected_modules.add(name)
|
||||
optional_parameters.remove(name)
|
||||
|
||||
return sorted(expected_modules), sorted(optional_parameters)
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
@classmethod
|
||||
def _get_signature_types(cls):
|
||||
@@ -1652,12 +1652,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
|
||||
}
|
||||
|
||||
actual = sorted(set(components.keys()))
|
||||
expected = sorted(expected_modules)
|
||||
if actual != expected:
|
||||
if set(components.keys()) != expected_modules:
|
||||
raise ValueError(
|
||||
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
|
||||
f" {expected} to be defined, but {actual} are defined."
|
||||
f" {expected_modules} to be defined, but {components.keys()} are defined."
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@@ -109,30 +109,14 @@ def prompt_clean(text):
|
||||
|
||||
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor,
|
||||
latents_mean: torch.Tensor,
|
||||
latents_std: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
sample_mode: str = "sample",
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
|
||||
encoder_output.latent_dist.logvar = torch.clamp(
|
||||
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
|
||||
)
|
||||
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
|
||||
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
|
||||
encoder_output.latent_dist.logvar = torch.clamp(
|
||||
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
|
||||
)
|
||||
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
|
||||
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return (encoder_output.latents - latents_mean) * latents_std
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
@@ -401,6 +385,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
)
|
||||
video_condition = video_condition.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list):
|
||||
latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
|
||||
latents = latent_condition = torch.cat(latent_condition)
|
||||
else:
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
|
||||
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
@@ -410,14 +401,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
latent_condition = [
|
||||
retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
|
||||
]
|
||||
latent_condition = torch.cat(latent_condition)
|
||||
else:
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
|
||||
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
|
||||
@@ -56,14 +56,3 @@ USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
|
||||
|
||||
if USE_PEFT_BACKEND and _CHECK_PEFT:
|
||||
dep_version_check("peft")
|
||||
|
||||
|
||||
DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
|
||||
|
||||
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
|
||||
@@ -362,21 +362,6 @@ class CogView3PlusPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CogView4ControlPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CogView4Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1247,21 +1232,6 @@ class LTXPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Lumina2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Lumina2Text2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1277,21 +1247,6 @@ class Lumina2Text2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LuminaPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LuminaText2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from types import ModuleType
|
||||
from typing import Any, Union
|
||||
|
||||
from huggingface_hub.utils import is_jinja_available # noqa: F401
|
||||
from packaging import version
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from . import logging
|
||||
@@ -51,30 +52,36 @@ DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
|
||||
|
||||
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
|
||||
|
||||
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
|
||||
|
||||
|
||||
def _is_package_available(pkg_name: str):
|
||||
pkg_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
pkg_version = "N/A"
|
||||
|
||||
if pkg_exists:
|
||||
try:
|
||||
pkg_version = importlib_metadata.version(pkg_name)
|
||||
logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
|
||||
except (ImportError, importlib_metadata.PackageNotFoundError):
|
||||
pkg_exists = False
|
||||
|
||||
return pkg_exists, pkg_version
|
||||
|
||||
|
||||
_torch_version = "N/A"
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
_torch_available, _torch_version = _is_package_available("torch")
|
||||
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
if _torch_available:
|
||||
try:
|
||||
_torch_version = importlib_metadata.version("torch")
|
||||
logger.info(f"PyTorch version {_torch_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TORCH is set")
|
||||
_torch_available = False
|
||||
|
||||
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
|
||||
if _torch_xla_available:
|
||||
try:
|
||||
_torch_xla_version = importlib_metadata.version("torch_xla")
|
||||
logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
|
||||
except ImportError:
|
||||
_torch_xla_available = False
|
||||
|
||||
# check whether torch_npu is available
|
||||
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
|
||||
if _torch_npu_available:
|
||||
try:
|
||||
_torch_npu_version = importlib_metadata.version("torch_npu")
|
||||
logger.info(f"torch_npu version {_torch_npu_version} available.")
|
||||
except ImportError:
|
||||
_torch_npu_available = False
|
||||
|
||||
_jax_version = "N/A"
|
||||
_flax_version = "N/A"
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
@@ -90,12 +97,47 @@ else:
|
||||
_flax_available = False
|
||||
|
||||
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_safetensors_available, _safetensors_version = _is_package_available("safetensors")
|
||||
|
||||
_safetensors_available = importlib.util.find_spec("safetensors") is not None
|
||||
if _safetensors_available:
|
||||
try:
|
||||
_safetensors_version = importlib_metadata.version("safetensors")
|
||||
logger.info(f"Safetensors version {_safetensors_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_safetensors_available = False
|
||||
else:
|
||||
logger.info("Disabling Safetensors because USE_TF is set")
|
||||
_safetensors_available = False
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
_transformers_version = importlib_metadata.version("transformers")
|
||||
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_transformers_available = False
|
||||
|
||||
_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None
|
||||
try:
|
||||
_hf_hub_version = importlib_metadata.version("huggingface_hub")
|
||||
logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_hf_hub_available = False
|
||||
|
||||
|
||||
_inflect_available = importlib.util.find_spec("inflect") is not None
|
||||
try:
|
||||
_inflect_version = importlib_metadata.version("inflect")
|
||||
logger.debug(f"Successfully imported inflect version {_inflect_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_inflect_available = False
|
||||
|
||||
|
||||
_unidecode_available = importlib.util.find_spec("unidecode") is not None
|
||||
try:
|
||||
_unidecode_version = importlib_metadata.version("unidecode")
|
||||
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_unidecode_available = False
|
||||
|
||||
_onnxruntime_version = "N/A"
|
||||
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
||||
if _onnx_available:
|
||||
@@ -144,6 +186,85 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_opencv_available = False
|
||||
|
||||
_scipy_available = importlib.util.find_spec("scipy") is not None
|
||||
try:
|
||||
_scipy_version = importlib_metadata.version("scipy")
|
||||
logger.debug(f"Successfully imported scipy version {_scipy_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_scipy_available = False
|
||||
|
||||
_librosa_available = importlib.util.find_spec("librosa") is not None
|
||||
try:
|
||||
_librosa_version = importlib_metadata.version("librosa")
|
||||
logger.debug(f"Successfully imported librosa version {_librosa_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_librosa_available = False
|
||||
|
||||
_accelerate_available = importlib.util.find_spec("accelerate") is not None
|
||||
try:
|
||||
_accelerate_version = importlib_metadata.version("accelerate")
|
||||
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_accelerate_available = False
|
||||
|
||||
_xformers_available = importlib.util.find_spec("xformers") is not None
|
||||
try:
|
||||
_xformers_version = importlib_metadata.version("xformers")
|
||||
if _torch_available:
|
||||
_torch_version = importlib_metadata.version("torch")
|
||||
if version.Version(_torch_version) < version.Version("1.12"):
|
||||
raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
|
||||
|
||||
logger.debug(f"Successfully imported xformers version {_xformers_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_xformers_available = False
|
||||
|
||||
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
|
||||
try:
|
||||
_k_diffusion_version = importlib_metadata.version("k_diffusion")
|
||||
logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_k_diffusion_available = False
|
||||
|
||||
_note_seq_available = importlib.util.find_spec("note_seq") is not None
|
||||
try:
|
||||
_note_seq_version = importlib_metadata.version("note_seq")
|
||||
logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_note_seq_available = False
|
||||
|
||||
_wandb_available = importlib.util.find_spec("wandb") is not None
|
||||
try:
|
||||
_wandb_version = importlib_metadata.version("wandb")
|
||||
logger.debug(f"Successfully imported wandb version {_wandb_version }")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_wandb_available = False
|
||||
|
||||
|
||||
_tensorboard_available = importlib.util.find_spec("tensorboard")
|
||||
try:
|
||||
_tensorboard_version = importlib_metadata.version("tensorboard")
|
||||
logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_tensorboard_available = False
|
||||
|
||||
|
||||
_compel_available = importlib.util.find_spec("compel")
|
||||
try:
|
||||
_compel_version = importlib_metadata.version("compel")
|
||||
logger.debug(f"Successfully imported compel version {_compel_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_compel_available = False
|
||||
|
||||
|
||||
_ftfy_available = importlib.util.find_spec("ftfy") is not None
|
||||
try:
|
||||
_ftfy_version = importlib_metadata.version("ftfy")
|
||||
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_ftfy_available = False
|
||||
|
||||
|
||||
_bs4_available = importlib.util.find_spec("bs4") is not None
|
||||
try:
|
||||
# importlib metadata under different name
|
||||
@@ -152,6 +273,13 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_bs4_available = False
|
||||
|
||||
_torchsde_available = importlib.util.find_spec("torchsde") is not None
|
||||
try:
|
||||
_torchsde_version = importlib_metadata.version("torchsde")
|
||||
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torchsde_available = False
|
||||
|
||||
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
|
||||
try:
|
||||
_invisible_watermark_version = importlib_metadata.version("invisible-watermark")
|
||||
@@ -159,42 +287,91 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_invisible_watermark_available = False
|
||||
|
||||
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
|
||||
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
|
||||
_transformers_available, _transformers_version = _is_package_available("transformers")
|
||||
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
|
||||
_inflect_available, _inflect_version = _is_package_available("inflect")
|
||||
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
|
||||
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
|
||||
_note_seq_available, _note_seq_version = _is_package_available("note_seq")
|
||||
_wandb_available, _wandb_version = _is_package_available("wandb")
|
||||
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
|
||||
_compel_available, _compel_version = _is_package_available("compel")
|
||||
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
|
||||
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
|
||||
_peft_available, _peft_version = _is_package_available("peft")
|
||||
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
|
||||
_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib")
|
||||
_timm_available, _timm_version = _is_package_available("timm")
|
||||
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
||||
_imageio_available, _imageio_version = _is_package_available("imageio")
|
||||
_ftfy_available, _ftfy_version = _is_package_available("ftfy")
|
||||
_scipy_available, _scipy_version = _is_package_available("scipy")
|
||||
_librosa_available, _librosa_version = _is_package_available("librosa")
|
||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate")
|
||||
_xformers_available, _xformers_version = _is_package_available("xformers")
|
||||
_gguf_available, _gguf_version = _is_package_available("gguf")
|
||||
_torchao_available, _torchao_version = _is_package_available("torchao")
|
||||
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
||||
_torchao_available, _torchao_version = _is_package_available("torchao")
|
||||
|
||||
_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
|
||||
if _optimum_quanto_available:
|
||||
_peft_available = importlib.util.find_spec("peft") is not None
|
||||
try:
|
||||
_peft_version = importlib_metadata.version("peft")
|
||||
logger.debug(f"Successfully imported peft version {_peft_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_peft_available = False
|
||||
|
||||
_torchvision_available = importlib.util.find_spec("torchvision") is not None
|
||||
try:
|
||||
_torchvision_version = importlib_metadata.version("torchvision")
|
||||
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torchvision_available = False
|
||||
|
||||
_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None
|
||||
try:
|
||||
_sentencepiece_version = importlib_metadata.version("sentencepiece")
|
||||
logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_sentencepiece_available = False
|
||||
|
||||
_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
|
||||
try:
|
||||
_matplotlib_version = importlib_metadata.version("matplotlib")
|
||||
logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_matplotlib_available = False
|
||||
|
||||
_timm_available = importlib.util.find_spec("timm") is not None
|
||||
if _timm_available:
|
||||
try:
|
||||
_timm_version = importlib_metadata.version("timm")
|
||||
logger.info(f"Timm version {_timm_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_timm_available = False
|
||||
|
||||
|
||||
def is_timm_available():
|
||||
return _timm_available
|
||||
|
||||
|
||||
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
|
||||
try:
|
||||
_bitsandbytes_version = importlib_metadata.version("bitsandbytes")
|
||||
logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_bitsandbytes_available = False
|
||||
|
||||
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
|
||||
|
||||
_imageio_available = importlib.util.find_spec("imageio") is not None
|
||||
if _imageio_available:
|
||||
try:
|
||||
_imageio_version = importlib_metadata.version("imageio")
|
||||
logger.debug(f"Successfully imported imageio version {_imageio_version}")
|
||||
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_imageio_available = False
|
||||
|
||||
_is_gguf_available = importlib.util.find_spec("gguf") is not None
|
||||
if _is_gguf_available:
|
||||
try:
|
||||
_gguf_version = importlib_metadata.version("gguf")
|
||||
logger.debug(f"Successfully import gguf version {_gguf_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_is_gguf_available = False
|
||||
|
||||
|
||||
_is_torchao_available = importlib.util.find_spec("torchao") is not None
|
||||
if _is_torchao_available:
|
||||
try:
|
||||
_torchao_version = importlib_metadata.version("torchao")
|
||||
logger.debug(f"Successfully import torchao version {_torchao_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_is_torchao_available = False
|
||||
|
||||
|
||||
_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
|
||||
if _is_optimum_quanto_available:
|
||||
try:
|
||||
_optimum_quanto_version = importlib_metadata.version("optimum_quanto")
|
||||
logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_optimum_quanto_available = False
|
||||
_is_optimum_quanto_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -318,19 +495,15 @@ def is_imageio_available():
|
||||
|
||||
|
||||
def is_gguf_available():
|
||||
return _gguf_available
|
||||
return _is_gguf_available
|
||||
|
||||
|
||||
def is_torchao_available():
|
||||
return _torchao_available
|
||||
return _is_torchao_available
|
||||
|
||||
|
||||
def is_optimum_quanto_available():
|
||||
return _optimum_quanto_available
|
||||
|
||||
|
||||
def is_timm_available():
|
||||
return _timm_available
|
||||
return _is_optimum_quanto_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
@@ -690,7 +863,7 @@ def is_gguf_version(operation: str, version: str):
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _gguf_available:
|
||||
if not _is_gguf_available:
|
||||
return False
|
||||
return compare_versions(parse(_gguf_version), operation, version)
|
||||
|
||||
@@ -705,7 +878,7 @@ def is_torchao_version(operation: str, version: str):
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _torchao_available:
|
||||
if not _is_torchao_available:
|
||||
return False
|
||||
return compare_versions(parse(_torchao_version), operation, version)
|
||||
|
||||
@@ -735,7 +908,7 @@ def is_optimum_quanto_version(operation: str, version: str):
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _optimum_quanto_available:
|
||||
if not _is_optimum_quanto_available:
|
||||
return False
|
||||
return compare_versions(parse(_optimum_quanto_version), operation, version)
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def check_inputs_decode(
|
||||
def check_inputs(
|
||||
endpoint: str,
|
||||
tensor: "torch.Tensor",
|
||||
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
|
||||
@@ -89,7 +89,7 @@ def check_inputs_decode(
|
||||
)
|
||||
|
||||
|
||||
def postprocess_decode(
|
||||
def postprocess(
|
||||
response: requests.Response,
|
||||
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
|
||||
output_type: Literal["mp4", "pil", "pt"] = "pil",
|
||||
@@ -142,7 +142,7 @@ def postprocess_decode(
|
||||
return output
|
||||
|
||||
|
||||
def prepare_decode(
|
||||
def prepare(
|
||||
tensor: "torch.Tensor",
|
||||
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
|
||||
do_scaling: bool = True,
|
||||
@@ -293,7 +293,7 @@ def remote_decode(
|
||||
standard_warn=False,
|
||||
)
|
||||
output_tensor_type = "binary"
|
||||
check_inputs_decode(
|
||||
check_inputs(
|
||||
endpoint,
|
||||
tensor,
|
||||
processor,
|
||||
@@ -309,7 +309,7 @@ def remote_decode(
|
||||
height,
|
||||
width,
|
||||
)
|
||||
kwargs = prepare_decode(
|
||||
kwargs = prepare(
|
||||
tensor=tensor,
|
||||
processor=processor,
|
||||
do_scaling=do_scaling,
|
||||
@@ -324,7 +324,7 @@ def remote_decode(
|
||||
response = requests.post(endpoint, **kwargs)
|
||||
if not response.ok:
|
||||
raise RuntimeError(response.json())
|
||||
output = postprocess_decode(
|
||||
output = postprocess(
|
||||
response=response,
|
||||
processor=processor,
|
||||
output_type=output_type,
|
||||
@@ -332,94 +332,3 @@ def remote_decode(
|
||||
partial_postprocess=partial_postprocess,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def check_inputs_encode(
|
||||
endpoint: str,
|
||||
image: Union["torch.Tensor", Image.Image],
|
||||
scaling_factor: Optional[float] = None,
|
||||
shift_factor: Optional[float] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def postprocess_encode(
|
||||
response: requests.Response,
|
||||
):
|
||||
output_tensor = response.content
|
||||
parameters = response.headers
|
||||
shape = json.loads(parameters["shape"])
|
||||
dtype = parameters["dtype"]
|
||||
torch_dtype = DTYPE_MAP[dtype]
|
||||
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def prepare_encode(
|
||||
image: Union["torch.Tensor", Image.Image],
|
||||
scaling_factor: Optional[float] = None,
|
||||
shift_factor: Optional[float] = None,
|
||||
):
|
||||
headers = {}
|
||||
parameters = {}
|
||||
if scaling_factor is not None:
|
||||
parameters["scaling_factor"] = scaling_factor
|
||||
if shift_factor is not None:
|
||||
parameters["shift_factor"] = shift_factor
|
||||
if isinstance(image, torch.Tensor):
|
||||
data = safetensors.torch._tobytes(image, "tensor")
|
||||
parameters["shape"] = list(image.shape)
|
||||
parameters["dtype"] = str(image.dtype).split(".")[-1]
|
||||
else:
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
data = buffer.getvalue()
|
||||
return {"data": data, "params": parameters, "headers": headers}
|
||||
|
||||
|
||||
def remote_encode(
|
||||
endpoint: str,
|
||||
image: Union["torch.Tensor", Image.Image],
|
||||
scaling_factor: Optional[float] = None,
|
||||
shift_factor: Optional[float] = None,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Hugging Face Hybrid Inference that allow running VAE encode remotely.
|
||||
|
||||
Args:
|
||||
endpoint (`str`):
|
||||
Endpoint for Remote Decode.
|
||||
image (`torch.Tensor` or `PIL.Image.Image`):
|
||||
Image to be encoded.
|
||||
scaling_factor (`float`, *optional*):
|
||||
Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`].
|
||||
- SD v1: 0.18215
|
||||
- SD XL: 0.13025
|
||||
- Flux: 0.3611
|
||||
If `None`, input must be passed with scaling applied.
|
||||
shift_factor (`float`, *optional*):
|
||||
Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`.
|
||||
- Flux: 0.1159
|
||||
If `None`, input must be passed with scaling applied.
|
||||
|
||||
Returns:
|
||||
output (`torch.Tensor`).
|
||||
"""
|
||||
check_inputs_encode(
|
||||
endpoint,
|
||||
image,
|
||||
scaling_factor,
|
||||
shift_factor,
|
||||
)
|
||||
kwargs = prepare_encode(
|
||||
image=image,
|
||||
scaling_factor=scaling_factor,
|
||||
shift_factor=shift_factor,
|
||||
)
|
||||
response = requests.post(endpoint, **kwargs)
|
||||
if not response.ok:
|
||||
raise RuntimeError(response.json())
|
||||
output = postprocess_encode(
|
||||
response=response,
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -101,8 +101,6 @@ if is_torch_available():
|
||||
mps_backend_registered = hasattr(torch.backends, "mps")
|
||||
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
|
||||
|
||||
from .torch_utils import get_torch_cuda_device_capability
|
||||
|
||||
|
||||
def torch_all_close(a, b, *args, **kwargs):
|
||||
if not is_torch_available():
|
||||
@@ -284,20 +282,6 @@ def require_torch_gpu(test_case):
|
||||
)
|
||||
|
||||
|
||||
def require_torch_cuda_compatibility(expected_compute_capability):
|
||||
def decorator(test_case):
|
||||
if not torch.cuda.is_available():
|
||||
return unittest.skip(test_case)
|
||||
else:
|
||||
current_compute_capability = get_torch_cuda_device_capability()
|
||||
return unittest.skipUnless(
|
||||
float(current_compute_capability) == float(expected_compute_capability),
|
||||
"Test not supported for this compute capability.",
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# These decorators are for accelerator-specific behaviours that are not GPU-specific
|
||||
def require_torch_accelerator(test_case):
|
||||
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
|
||||
|
||||
+2
-2
@@ -1961,7 +1961,7 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
|
||||
logger = logging.get_logger("diffusers.loaders.peft")
|
||||
logger.setLevel(logging.WARNING)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(no_op_state_dict)
|
||||
@@ -1981,7 +1981,7 @@ class PeftLoraLoaderMixinTests:
|
||||
prefix = "text_encoder_2"
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_base")
|
||||
logger.setLevel(logging.WARNING)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
self.pipeline_class.load_lora_into_text_encoder(
|
||||
|
||||
@@ -5,13 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LuminaNextDiT2DModel,
|
||||
LuminaPipeline,
|
||||
LuminaText2ImgPipeline,
|
||||
)
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
@@ -23,8 +17,8 @@ from diffusers.utils.testing_utils import (
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
class LuminaPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = LuminaPipeline
|
||||
class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = LuminaText2ImgPipeline
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -105,17 +99,11 @@ class LuminaPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
pass
|
||||
|
||||
def test_deprecation_raises_warning(self):
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
_ = LuminaText2ImgPipeline(**self.get_dummy_components()).to(torch_device)
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "renamed to `LuminaPipeline`" in warning_message
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
class LuminaPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = LuminaPipeline
|
||||
class LuminaText2ImgPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = LuminaText2ImgPipeline
|
||||
repo_id = "Alpha-VLLM/Lumina-Next-SFT-diffusers"
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -6,17 +6,15 @@ from transformers import AutoTokenizer, Gemma2Config, Gemma2Model
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
Lumina2Transformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
class Lumina2PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = Lumina2Pipeline
|
||||
class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = Lumina2Text2ImgPipeline
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -117,9 +115,3 @@ class Lumina2PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_deprecation_raises_warning(self):
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
_ = Lumina2Text2ImgPipeline(**self.get_dummy_components()).to(torch_device)
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "renamed to `Lumina2Pipeline`" in warning_message
|
||||
|
||||
@@ -19,7 +19,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
@@ -826,104 +826,3 @@ class ProgressBarTests(unittest.TestCase):
|
||||
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
|
||||
_ = pipe(**inputs)
|
||||
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
|
||||
expected_pipe_device = torch.device("cuda:0")
|
||||
expected_pipe_dtype = torch.float64
|
||||
|
||||
def get_dummy_components_image_generation(self):
|
||||
cross_attention_dim = 8
|
||||
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=1,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[4, 8],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=cross_attention_dim,
|
||||
intermediate_size=16,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=2,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def test_deterministic_device(self):
|
||||
components = self.get_dummy_components_image_generation()
|
||||
|
||||
pipe = StableDiffusionPipeline(**components)
|
||||
pipe.to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
pipe.unet.to(device="cpu")
|
||||
pipe.vae.to(device="cuda")
|
||||
pipe.text_encoder.to(device="cuda:0")
|
||||
|
||||
pipe_device = pipe.device
|
||||
|
||||
self.assertEqual(
|
||||
self.expected_pipe_device,
|
||||
pipe_device,
|
||||
f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
|
||||
)
|
||||
|
||||
def test_deterministic_dtype(self):
|
||||
components = self.get_dummy_components_image_generation()
|
||||
|
||||
pipe = StableDiffusionPipeline(**components)
|
||||
pipe.to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
pipe.unet.to(dtype=torch.float16)
|
||||
pipe.vae.to(dtype=torch.float32)
|
||||
pipe.text_encoder.to(dtype=torch.float64)
|
||||
|
||||
pipe_dtype = pipe.dtype
|
||||
|
||||
self.assertEqual(
|
||||
self.expected_pipe_dtype,
|
||||
pipe_dtype,
|
||||
f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
|
||||
)
|
||||
|
||||
@@ -33,7 +33,6 @@ from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate,
|
||||
require_bitsandbytes_version_greater,
|
||||
require_peft_backend,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_transformers_version_greater,
|
||||
@@ -669,7 +668,6 @@ class SlowBnb4BitFluxTests(Base4bitTests):
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
|
||||
self.assertTrue(max_diff < 1e-3)
|
||||
|
||||
@require_peft_backend
|
||||
def test_lora_loading(self):
|
||||
self.pipeline_4bit.load_lora_weights(
|
||||
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
|
||||
|
||||
@@ -10,7 +10,6 @@ from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_torch_cuda_compatibility,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -312,7 +311,6 @@ class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa
|
||||
return {"weights_dtype": "int8"}
|
||||
|
||||
|
||||
@require_torch_cuda_compatibility(8.0)
|
||||
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
|
||||
expected_memory_reduction = 0.55
|
||||
|
||||
@@ -320,7 +318,6 @@ class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa
|
||||
return {"weights_dtype": "int4"}
|
||||
|
||||
|
||||
@require_torch_cuda_compatibility(8.0)
|
||||
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
|
||||
expected_memory_reduction = 0.65
|
||||
|
||||
|
||||
@@ -21,15 +21,7 @@ import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.utils.constants import (
|
||||
DECODE_ENDPOINT_FLUX,
|
||||
DECODE_ENDPOINT_HUNYUAN_VIDEO,
|
||||
DECODE_ENDPOINT_SD_V1,
|
||||
DECODE_ENDPOINT_SD_XL,
|
||||
)
|
||||
from diffusers.utils.remote_utils import (
|
||||
remote_decode,
|
||||
)
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
slow,
|
||||
@@ -41,6 +33,11 @@ from diffusers.video_processor import VideoProcessor
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
|
||||
|
||||
class RemoteAutoencoderKLMixin:
|
||||
shape: Tuple[int, ...] = None
|
||||
@@ -353,7 +350,7 @@ class RemoteAutoencoderKLSDv1Tests(
|
||||
512,
|
||||
512,
|
||||
)
|
||||
endpoint = DECODE_ENDPOINT_SD_V1
|
||||
endpoint = ENDPOINT_SD_V1
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.18215
|
||||
shift_factor = None
|
||||
@@ -377,7 +374,7 @@ class RemoteAutoencoderKLSDXLTests(
|
||||
1024,
|
||||
1024,
|
||||
)
|
||||
endpoint = DECODE_ENDPOINT_SD_XL
|
||||
endpoint = ENDPOINT_SD_XL
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.13025
|
||||
shift_factor = None
|
||||
@@ -401,7 +398,7 @@ class RemoteAutoencoderKLFluxTests(
|
||||
1024,
|
||||
1024,
|
||||
)
|
||||
endpoint = DECODE_ENDPOINT_FLUX
|
||||
endpoint = ENDPOINT_FLUX
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3611
|
||||
shift_factor = 0.1159
|
||||
@@ -428,7 +425,7 @@ class RemoteAutoencoderKLFluxPackedTests(
|
||||
)
|
||||
height = 1024
|
||||
width = 1024
|
||||
endpoint = DECODE_ENDPOINT_FLUX
|
||||
endpoint = ENDPOINT_FLUX
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3611
|
||||
shift_factor = 0.1159
|
||||
@@ -456,7 +453,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
|
||||
320,
|
||||
512,
|
||||
)
|
||||
endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO
|
||||
endpoint = ENDPOINT_HUNYUAN_VIDEO
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.476986
|
||||
processor_cls = VideoProcessor
|
||||
@@ -507,7 +504,7 @@ class RemoteAutoencoderKLSDv1SlowTests(
|
||||
RemoteAutoencoderKLSlowTestMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
endpoint = DECODE_ENDPOINT_SD_V1
|
||||
endpoint = ENDPOINT_SD_V1
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.18215
|
||||
shift_factor = None
|
||||
@@ -518,7 +515,7 @@ class RemoteAutoencoderKLSDXLSlowTests(
|
||||
RemoteAutoencoderKLSlowTestMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
endpoint = DECODE_ENDPOINT_SD_XL
|
||||
endpoint = ENDPOINT_SD_XL
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.13025
|
||||
shift_factor = None
|
||||
@@ -530,7 +527,7 @@ class RemoteAutoencoderKLFluxSlowTests(
|
||||
unittest.TestCase,
|
||||
):
|
||||
channels = 16
|
||||
endpoint = DECODE_ENDPOINT_FLUX
|
||||
endpoint = ENDPOINT_FLUX
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3611
|
||||
shift_factor = 0.1159
|
||||
|
||||
@@ -1,224 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.constants import (
|
||||
DECODE_ENDPOINT_FLUX,
|
||||
DECODE_ENDPOINT_SD_V1,
|
||||
DECODE_ENDPOINT_SD_XL,
|
||||
ENCODE_ENDPOINT_FLUX,
|
||||
ENCODE_ENDPOINT_SD_V1,
|
||||
ENCODE_ENDPOINT_SD_XL,
|
||||
)
|
||||
from diffusers.utils.remote_utils import (
|
||||
remote_decode,
|
||||
remote_encode,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
slow,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
|
||||
|
||||
|
||||
class RemoteAutoencoderKLEncodeMixin:
|
||||
channels: int = None
|
||||
endpoint: str = None
|
||||
decode_endpoint: str = None
|
||||
dtype: torch.dtype = None
|
||||
scaling_factor: float = None
|
||||
shift_factor: float = None
|
||||
image: PIL.Image.Image = None
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
if self.image is None:
|
||||
self.image = load_image(IMAGE)
|
||||
inputs = {
|
||||
"endpoint": self.endpoint,
|
||||
"image": self.image,
|
||||
"scaling_factor": self.scaling_factor,
|
||||
"shift_factor": self.shift_factor,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_image_input(self):
|
||||
inputs = self.get_dummy_inputs()
|
||||
height, width = inputs["image"].height, inputs["image"].width
|
||||
output = remote_encode(**inputs)
|
||||
self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
|
||||
decoded = remote_decode(
|
||||
tensor=output,
|
||||
endpoint=self.decode_endpoint,
|
||||
scaling_factor=self.scaling_factor,
|
||||
shift_factor=self.shift_factor,
|
||||
image_format="png",
|
||||
)
|
||||
self.assertEqual(decoded.height, height)
|
||||
self.assertEqual(decoded.width, width)
|
||||
# image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten())
|
||||
# decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten())
|
||||
# TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?
|
||||
|
||||
|
||||
class RemoteAutoencoderKLSDv1Tests(
|
||||
RemoteAutoencoderKLEncodeMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
channels = 4
|
||||
endpoint = ENCODE_ENDPOINT_SD_V1
|
||||
decode_endpoint = DECODE_ENDPOINT_SD_V1
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.18215
|
||||
shift_factor = None
|
||||
|
||||
|
||||
class RemoteAutoencoderKLSDXLTests(
|
||||
RemoteAutoencoderKLEncodeMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
channels = 4
|
||||
endpoint = ENCODE_ENDPOINT_SD_XL
|
||||
decode_endpoint = DECODE_ENDPOINT_SD_XL
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.13025
|
||||
shift_factor = None
|
||||
|
||||
|
||||
class RemoteAutoencoderKLFluxTests(
|
||||
RemoteAutoencoderKLEncodeMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
channels = 16
|
||||
endpoint = ENCODE_ENDPOINT_FLUX
|
||||
decode_endpoint = DECODE_ENDPOINT_FLUX
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3611
|
||||
shift_factor = 0.1159
|
||||
|
||||
|
||||
class RemoteAutoencoderKLEncodeSlowTestMixin:
|
||||
channels: int = 4
|
||||
endpoint: str = None
|
||||
decode_endpoint: str = None
|
||||
dtype: torch.dtype = None
|
||||
scaling_factor: float = None
|
||||
shift_factor: float = None
|
||||
image: PIL.Image.Image = None
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
if self.image is None:
|
||||
self.image = load_image(IMAGE)
|
||||
inputs = {
|
||||
"endpoint": self.endpoint,
|
||||
"image": self.image,
|
||||
"scaling_factor": self.scaling_factor,
|
||||
"shift_factor": self.shift_factor,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_multi_res(self):
|
||||
inputs = self.get_dummy_inputs()
|
||||
for height in {
|
||||
320,
|
||||
512,
|
||||
640,
|
||||
704,
|
||||
896,
|
||||
1024,
|
||||
1208,
|
||||
1384,
|
||||
1536,
|
||||
1608,
|
||||
1864,
|
||||
2048,
|
||||
}:
|
||||
for width in {
|
||||
320,
|
||||
512,
|
||||
640,
|
||||
704,
|
||||
896,
|
||||
1024,
|
||||
1208,
|
||||
1384,
|
||||
1536,
|
||||
1608,
|
||||
1864,
|
||||
2048,
|
||||
}:
|
||||
inputs["image"] = inputs["image"].resize(
|
||||
(
|
||||
width,
|
||||
height,
|
||||
)
|
||||
)
|
||||
output = remote_encode(**inputs)
|
||||
self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
|
||||
decoded = remote_decode(
|
||||
tensor=output,
|
||||
endpoint=self.decode_endpoint,
|
||||
scaling_factor=self.scaling_factor,
|
||||
shift_factor=self.shift_factor,
|
||||
image_format="png",
|
||||
)
|
||||
self.assertEqual(decoded.height, height)
|
||||
self.assertEqual(decoded.width, width)
|
||||
decoded.save(f"test_multi_res_{height}_{width}.png")
|
||||
|
||||
|
||||
@slow
|
||||
class RemoteAutoencoderKLSDv1SlowTests(
|
||||
RemoteAutoencoderKLEncodeSlowTestMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
endpoint = ENCODE_ENDPOINT_SD_V1
|
||||
decode_endpoint = DECODE_ENDPOINT_SD_V1
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.18215
|
||||
shift_factor = None
|
||||
|
||||
|
||||
@slow
|
||||
class RemoteAutoencoderKLSDXLSlowTests(
|
||||
RemoteAutoencoderKLEncodeSlowTestMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
endpoint = ENCODE_ENDPOINT_SD_XL
|
||||
decode_endpoint = DECODE_ENDPOINT_SD_XL
|
||||
dtype = torch.float16
|
||||
scaling_factor = 0.13025
|
||||
shift_factor = None
|
||||
|
||||
|
||||
@slow
|
||||
class RemoteAutoencoderKLFluxSlowTests(
|
||||
RemoteAutoencoderKLEncodeSlowTestMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
channels = 16
|
||||
endpoint = ENCODE_ENDPOINT_FLUX
|
||||
decode_endpoint = DECODE_ENDPOINT_FLUX
|
||||
dtype = torch.bfloat16
|
||||
scaling_factor = 0.3611
|
||||
shift_factor = 0.1159
|
||||
Reference in New Issue
Block a user