Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6a94ef7388 | |||
| f20e4afbaa | |||
| 7787ec11c8 | |||
| c5c7588648 | |||
| 6ccaed77ed | |||
| f6ece89c6d | |||
| 542a6034d3 | |||
| e95ac9d82f | |||
| 104e1636b2 | |||
| 373106cedb | |||
| 8ceed7d3ae | |||
| 9836f0e000 | |||
| 20379d9d13 | |||
| 3a6caba8e4 | |||
| 4267d8f4eb | |||
| f4fa3beee7 | |||
| 7e3353196c | |||
| 8c249d1401 | |||
| b555a03723 | |||
| 06fee551e9 | |||
| 8b99f7e157 | |||
| 07dd6f8c0e | |||
| f8d4a1e283 | |||
| ddd0cfb497 | |||
| 4f438de35a | |||
| 98cc6d05e4 | |||
| c3726153fd | |||
| e48f6aeeb4 | |||
| 01abfc8736 | |||
| 92fe689f06 | |||
| 0ba1f76d4d | |||
| d6bf268a4a | |||
| 3c0a0129fe | |||
| 2d380895e5 | |||
| 0c47c954f3 | |||
| 7acf8345f6 | |||
| 599c887164 | |||
| 393aefcdc7 | |||
| 6674a5157f | |||
| 784db0eaab | |||
| 66e50d4e24 | |||
| c5c34a4591 | |||
| 87e508f11f | |||
| 53bd367b03 | |||
| 7b904941bc | |||
| fb29132b98 | |||
| 79371661d1 | |||
| 8c661ea586 | |||
| d7ffe60166 | |||
| 10bee525e7 | |||
| d88ae1f52a |
@@ -142,6 +142,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
@@ -525,6 +526,60 @@ jobs:
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_nightly_pipeline_level_quantization_tests:
|
||||
name: Torch quantization nightly tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "20gb" --ipc host --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
run: nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install -U bitsandbytes optimum_quanto
|
||||
python -m uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Pipeline-level quantization tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
BIG_GPU_MEMORY: 40
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
--make-reports=tests_pipeline_level_quant_torch_cuda \
|
||||
--report-log=tests_pipeline_level_quant_torch_cuda.log \
|
||||
tests/quantization/test_pipeline_level_quantization.py
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt
|
||||
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_pipeline_level_quant_reports
|
||||
path: reports
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# M1 runner currently not well supported
|
||||
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
|
||||
# run_nightly_tests_apple_m1:
|
||||
|
||||
@@ -295,6 +295,8 @@
|
||||
title: CogView4Transformer2DModel
|
||||
- local: api/models/consisid_transformer3d
|
||||
title: ConsisIDTransformer3DModel
|
||||
- local: api/models/cosmos_transformer3d
|
||||
title: CosmosTransformer3DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
@@ -363,6 +365,8 @@
|
||||
title: AutoencoderKLAllegro
|
||||
- local: api/models/autoencoderkl_cogvideox
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/autoencoderkl_cosmos
|
||||
title: AutoencoderKLCosmos
|
||||
- local: api/models/autoencoder_kl_hunyuan_video
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
@@ -433,6 +437,8 @@
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/ddim
|
||||
@@ -451,6 +457,8 @@
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
@@ -567,6 +575,8 @@
|
||||
title: UniDiffuser
|
||||
- local: api/pipelines/value_guided_sampling
|
||||
title: Value-guided sampling
|
||||
- local: api/pipelines/visualcloze
|
||||
title: VisualCloze
|
||||
- local: api/pipelines/wan
|
||||
title: Wan
|
||||
- local: api/pipelines/wuerstchen
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLCosmos
|
||||
|
||||
[Cosmos Tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer).
|
||||
|
||||
Supported models:
|
||||
- [nvidia/Cosmos-1.0-Tokenizer-CV8x8x8](https://huggingface.co/nvidia/Cosmos-1.0-Tokenizer-CV8x8x8)
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLCosmos
|
||||
|
||||
vae = AutoencoderKLCosmos.from_pretrained("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8", subfolder="vae")
|
||||
```
|
||||
|
||||
## AutoencoderKLCosmos
|
||||
|
||||
[[autodoc]] AutoencoderKLCosmos
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# CosmosTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import CosmosTransformer3DModel
|
||||
|
||||
transformer = CosmosTransformer3DModel.from_pretrained("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## CosmosTransformer3DModel
|
||||
|
||||
[[autodoc]] CosmosTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
|
||||
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## Loading GGUF quantized checkpoints for HiDream-I1
|
||||
|
||||
GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file`
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel
|
||||
|
||||
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
|
||||
transformer = HiDreamImageTransformer2DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
## HiDreamImageTransformer2DModel
|
||||
|
||||
[[autodoc]] HiDreamImageTransformer2DModel
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# Cosmos
|
||||
|
||||
[Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
|
||||
|
||||
*Physical AI needs to be trained digitally first. It needs a digital twin of itself, the policy model, and a digital twin of the world, the world model. In this paper, we present the Cosmos World Foundation Model Platform to help developers build customized world models for their Physical AI setups. We position a world foundation model as a general-purpose world model that can be fine-tuned into customized world models for downstream applications. Our platform covers a video curation pipeline, pre-trained world foundation models, examples of post-training of pre-trained world foundation models, and video tokenizers. To help Physical AI builders solve the most critical problems of our society, we make our platform open-source and our models open-weight with permissive licenses available via https://github.com/NVIDIA/Cosmos.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## CosmosTextToWorldPipeline
|
||||
|
||||
[[autodoc]] CosmosTextToWorldPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosVideoToWorldPipeline
|
||||
|
||||
[[autodoc]] CosmosVideoToWorldPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||
@@ -0,0 +1,209 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# Framepack
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
[Packing Input Frame Context in Next-Frame Prediction Models for Video Generation](https://arxiv.org/abs/2504.12626) by Lvmin Zhang and Maneesh Agrawala.
|
||||
|
||||
*We present a neural network structure, FramePack, to train next-frame (or next-frame-section) prediction models for video generation. The FramePack compresses input frames to make the transformer context length a fixed number regardless of the video length. As a result, we are able to process a large number of frames using video diffusion with computation bottleneck similar to image diffusion. This also makes the training video batch sizes significantly higher (batch sizes become comparable to image diffusion training). We also propose an anti-drifting sampling method that generates frames in inverted temporal order with early-established endpoints to avoid exposure bias (error accumulation over iterations). Finally, we show that existing video diffusion models can be finetuned with FramePack, and their visual quality may be improved because the next-frame prediction supports more balanced diffusion schedulers with less extreme flow shift timesteps.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Available models
|
||||
|
||||
| Model name | Description |
|
||||
|:---|:---|
|
||||
- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | Trained with the "inverted anti-drifting" strategy as described in the paper. Inference requires setting `sampling_type="inverted_anti_drifting"` when running the pipeline. |
|
||||
- [`lllyasviel/FramePack_F1_I2V_HY_20250503`](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503) | Trained with a novel anti-drifting strategy but inference is performed in "vanilla" strategy as described in the paper. Inference requires setting `sampling_type="vanilla"` when running the pipeline. |
|
||||
|
||||
## Usage
|
||||
|
||||
Refer to the pipeline documentation for basic usage examples. The following section contains examples of offloading, different sampling methods, quantization, and more.
|
||||
|
||||
### First and last frame to video
|
||||
|
||||
The following example shows how to use Framepack with start and end image controls, using the inverted anti-drifiting sampling model.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
|
||||
"lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
|
||||
)
|
||||
feature_extractor = SiglipImageProcessor.from_pretrained(
|
||||
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
|
||||
)
|
||||
image_encoder = SiglipVisionModel.from_pretrained(
|
||||
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanVideo",
|
||||
transformer=transformer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
# Enable memory optimizations
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
|
||||
first_image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
|
||||
)
|
||||
last_image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
|
||||
)
|
||||
output = pipe(
|
||||
image=first_image,
|
||||
last_image=last_image,
|
||||
prompt=prompt,
|
||||
height=512,
|
||||
width=512,
|
||||
num_frames=91,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=9.0,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
sampling_type="inverted_anti_drifting",
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
### Vanilla sampling
|
||||
|
||||
The following example shows how to use Framepack with the F1 model trained with vanilla sampling but new regulation approach for anti-drifting.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
|
||||
"lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16
|
||||
)
|
||||
feature_extractor = SiglipImageProcessor.from_pretrained(
|
||||
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
|
||||
)
|
||||
image_encoder = SiglipVisionModel.from_pretrained(
|
||||
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanVideo",
|
||||
transformer=transformer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
# Enable memory optimizations
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
|
||||
)
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt="A penguin dancing in the snow",
|
||||
height=832,
|
||||
width=480,
|
||||
num_frames=91,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=9.0,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
sampling_type="vanilla",
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
### Group offloading
|
||||
|
||||
Group offloading ([`~hooks.apply_group_offloading`]) provides aggressive memory optimizations for offloading internal parts of any model to the CPU, with possibly no additional overhead to generation time. If you have very low VRAM available, this approach may be suitable for you depending on the amount of CPU RAM available.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
|
||||
"lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16
|
||||
)
|
||||
feature_extractor = SiglipImageProcessor.from_pretrained(
|
||||
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
|
||||
)
|
||||
image_encoder = SiglipVisionModel.from_pretrained(
|
||||
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanVideo",
|
||||
transformer=transformer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
# Enable group offloading
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
list(map(
|
||||
lambda x: apply_group_offloading(x, onload_device, offload_device, offload_type="leaf_level", use_stream=True, low_cpu_mem_usage=True),
|
||||
[pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]
|
||||
))
|
||||
pipe.image_encoder.to(onload_device)
|
||||
pipe.vae.to(onload_device)
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
|
||||
)
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt="A penguin dancing in the snow",
|
||||
height=832,
|
||||
width=480,
|
||||
num_frames=91,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=9.0,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
sampling_type="vanilla",
|
||||
).frames[0]
|
||||
print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")
|
||||
export_to_video(output, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
## HunyuanVideoFramepackPipeline
|
||||
|
||||
[[autodoc]] HunyuanVideoFramepackPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HunyuanVideoPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
|
||||
|
||||
@@ -31,12 +31,103 @@ Available models:
|
||||
|
||||
| Model name | Recommended dtype |
|
||||
|:-------------:|:-----------------:|
|
||||
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 2B 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 2B 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 2B 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 13B 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video Spatial Upscaler 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) | `torch.bfloat16` |
|
||||
|
||||
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
|
||||
|
||||
## Recommended settings for generation
|
||||
|
||||
For the best results, it is recommended to follow the guidelines mentioned in the official LTX Video [repository](https://github.com/Lightricks/LTX-Video).
|
||||
|
||||
- Some variants of LTX Video are guidance-distilled. For guidance-distilled models, `guidance_scale` must be set to `1.0`. For any other models, `guidance_scale` should be set higher (e.g., `5.0`) for good generation quality.
|
||||
- For variants with a timestep-aware VAE (LTXV 0.9.1 and above), it is recommended to set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
|
||||
- For variants that support interpolation between multiple conditioning images and videos (LTXV 0.9.5 and above), it is recommended to use similar looking images/videos for the best results. High divergence between the conditionings may lead to abrupt transitions in the generated video.
|
||||
|
||||
## Using LTX Video 13B 0.9.7
|
||||
|
||||
LTX Video 0.9.7 comes with a spatial latent upscaler and a 13B parameter transformer. The inference involves generating a low resolution video first, which is very fast, followed by upscaling and refining the generated video.
|
||||
|
||||
<!-- TODO(aryan): modify when official checkpoints are available -->
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
|
||||
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
|
||||
from diffusers.utils import export_to_video, load_video
|
||||
|
||||
pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
pipe_upsample.to("cuda")
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
def round_to_nearest_resolution_acceptable_by_vae(height, width):
|
||||
height = height - (height % pipe.vae_temporal_compression_ratio)
|
||||
width = width - (width % pipe.vae_temporal_compression_ratio)
|
||||
return height, width
|
||||
|
||||
video = load_video(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
|
||||
)[:21] # Use only the first 21 frames as conditioning
|
||||
condition1 = LTXVideoCondition(video=video, frame_index=0)
|
||||
|
||||
prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
expected_height, expected_width = 768, 1152
|
||||
downscale_factor = 2 / 3
|
||||
num_frames = 161
|
||||
|
||||
# Part 1. Generate video at smaller resolution
|
||||
# Text-only conditioning is also supported without the need to pass `conditions`
|
||||
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
|
||||
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
|
||||
latents = pipe(
|
||||
conditions=[condition1],
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=downscaled_width,
|
||||
height=downscaled_height,
|
||||
num_frames=num_frames,
|
||||
num_inference_steps=30,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
output_type="latent",
|
||||
).frames
|
||||
|
||||
# Part 2. Upscale generated video using latent upsampler with fewer inference steps
|
||||
# The available latent upsampler upscales the height/width by 2x
|
||||
upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
|
||||
upscaled_latents = pipe_upsample(
|
||||
latents=latents,
|
||||
output_type="latent"
|
||||
).frames
|
||||
|
||||
# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
|
||||
video = pipe(
|
||||
conditions=[condition1],
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=upscaled_width,
|
||||
height=upscaled_height,
|
||||
num_frames=num_frames,
|
||||
denoise_strength=0.4, # Effectively, 4 inference steps out of 10
|
||||
num_inference_steps=10,
|
||||
latents=upscaled_latents,
|
||||
decode_timestep=0.05,
|
||||
image_cond_noise_scale=0.025,
|
||||
generator=torch.Generator().manual_seed(0),
|
||||
output_type="pil",
|
||||
).frames[0]
|
||||
|
||||
# Part 4. Downscale the video to the expected resolution
|
||||
video = [frame.resize((expected_width, expected_height)) for frame in video]
|
||||
|
||||
export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
|
||||
## Loading Single Files
|
||||
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
|
||||
@@ -204,6 +295,12 @@ export_to_video(video, "ship.mp4", fps=24)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXLatentUpsamplePipeline
|
||||
|
||||
[[autodoc]] LTXLatentUpsamplePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
|
||||
|
||||
@@ -89,6 +89,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation |
|
||||
| [Value-guided planning](value_guided_sampling) | value guided sampling |
|
||||
| [Wuerstchen](wuerstchen) | text2image |
|
||||
| [VisualCloze](visualcloze) | text2image, image2image, subject driven generation, inpainting, style transfer, image restoration, image editing, [depth,normal,edge,pose]2image, [depth,normal,edge,pose]-estimation, virtual try-on, image relighting |
|
||||
|
||||
## DiffusionPipeline
|
||||
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# VisualCloze
|
||||
|
||||
[VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning](https://arxiv.org/abs/2504.07960) is an innovative in-context learning based universal image generation framework that offers key capabilities:
|
||||
1. Support for various in-domain tasks
|
||||
2. Generalization to unseen tasks through in-context learning
|
||||
3. Unify multiple tasks into one step and generate both target image and intermediate results
|
||||
4. Support reverse-engineering conditions from target images
|
||||
|
||||
## Overview
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Recent progress in diffusion models significantly advances various image generation tasks. However, the current mainstream approach remains focused on building task-specific models, which have limited efficiency when supporting a wide range of different needs. While universal models attempt to address this limitation, they face critical challenges, including generalizable task instruction, appropriate task distributions, and unified architectural design. To tackle these challenges, we propose VisualCloze, a universal image generation framework, which supports a wide range of in-domain tasks, generalization to unseen ones, unseen unification of multiple tasks, and reverse generation. Unlike existing methods that rely on language-based task instruction, leading to task ambiguity and weak generalization, we integrate visual in-context learning, allowing models to identify tasks from visual demonstrations. Meanwhile, the inherent sparsity of visual task distributions hampers the learning of transferable knowledge across tasks. To this end, we introduce Graph200K, a graph-structured dataset that establishes various interrelated tasks, enhancing task density and transferable knowledge. Furthermore, we uncover that our unified image generation formulation shared a consistent objective with image infilling, enabling us to leverage the strong generative priors of pre-trained infilling models without modifying the architectures. The codes, dataset, and models are available at https://visualcloze.github.io.*
|
||||
|
||||
## Inference
|
||||
|
||||
### Model loading
|
||||
|
||||
VisualCloze is a two-stage cascade pipeline, containing `VisualClozeGenerationPipeline` and `VisualClozeUpsamplingPipeline`.
|
||||
- In `VisualClozeGenerationPipeline`, each image is downsampled before concatenating images into a grid layout, avoiding excessively high resolutions. VisualCloze releases two models suitable for diffusers, i.e., [VisualClozePipeline-384](https://huggingface.co/VisualCloze/VisualClozePipeline-384) and [VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-384), which downsample images to resolutions of 384 and 512, respectively.
|
||||
- `VisualClozeUpsamplingPipeline` uses [SDEdit](https://arxiv.org/abs/2108.01073) to enable high-resolution image synthesis.
|
||||
|
||||
The `VisualClozePipeline` integrates both stages to support convenient end-to-end sampling, while also allowing users to utilize each pipeline independently as needed.
|
||||
|
||||
### Input Specifications
|
||||
|
||||
#### Task and Content Prompts
|
||||
- Task prompt: Required to describe the generation task intention
|
||||
- Content prompt: Optional description or caption of the target image
|
||||
- When content prompt is not needed, pass `None`
|
||||
- For batch inference, pass `List[str|None]`
|
||||
|
||||
#### Image Input Format
|
||||
- Format: `List[List[Image|None]]`
|
||||
- Structure:
|
||||
- All rows except the last represent in-context examples
|
||||
- Last row represents the current query (target image set to `None`)
|
||||
- For batch inference, pass `List[List[List[Image|None]]]`
|
||||
|
||||
#### Resolution Control
|
||||
- Default behavior:
|
||||
- Initial generation in the first stage: area of ${pipe.resolution}^2$
|
||||
- Upsampling in the second stage: 3x factor
|
||||
- Custom resolution: Adjust using `upsampling_height` and `upsampling_width` parameters
|
||||
|
||||
### Examples
|
||||
|
||||
For comprehensive examples covering a wide range of tasks, please refer to the [Online Demo](https://huggingface.co/spaces/VisualCloze/VisualCloze) and [GitHub Repository](https://github.com/lzyhha/VisualCloze). Below are simple examples for three cases: mask-to-image conversion, edge detection, and subject-driven generation.
|
||||
|
||||
#### Example for mask2image
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import VisualClozePipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Load in-context images (make sure the paths are correct and accessible)
|
||||
image_paths = [
|
||||
# in-context examples
|
||||
[
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg'),
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg'),
|
||||
],
|
||||
# query with the target image
|
||||
[
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg'),
|
||||
None, # No image needed for the target image
|
||||
],
|
||||
]
|
||||
|
||||
# Task and content prompt
|
||||
task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
|
||||
content_prompt = """Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape.
|
||||
The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible.
|
||||
Its plumage is a mix of dark brown and golden hues, with intricate feather details.
|
||||
The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere.
|
||||
The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field,
|
||||
soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background,
|
||||
tranquil, majestic, wildlife photography."""
|
||||
|
||||
# Run the pipeline
|
||||
image_result = pipe(
|
||||
task_prompt=task_prompt,
|
||||
content_prompt=content_prompt,
|
||||
image=image_paths,
|
||||
upsampling_width=1344,
|
||||
upsampling_height=768,
|
||||
upsampling_strength=0.4,
|
||||
guidance_scale=30,
|
||||
num_inference_steps=30,
|
||||
max_sequence_length=512,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0][0]
|
||||
|
||||
# Save the resulting image
|
||||
image_result.save("visualcloze.png")
|
||||
```
|
||||
|
||||
#### Example for edge-detection
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import VisualClozePipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Load in-context images (make sure the paths are correct and accessible)
|
||||
image_paths = [
|
||||
# in-context examples
|
||||
[
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_image.jpg'),
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_edge.jpg'),
|
||||
],
|
||||
[
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_image.jpg'),
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_edge.jpg'),
|
||||
],
|
||||
# query with the target image
|
||||
[
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_query_image.jpg'),
|
||||
None, # No image needed for the target image
|
||||
],
|
||||
]
|
||||
|
||||
# Task and content prompt
|
||||
task_prompt = "Each row illustrates a pathway from [IMAGE1] a sharp and beautifully composed photograph to [IMAGE2] edge map with natural well-connected outlines using a clear logical task."
|
||||
content_prompt = ""
|
||||
|
||||
# Run the pipeline
|
||||
image_result = pipe(
|
||||
task_prompt=task_prompt,
|
||||
content_prompt=content_prompt,
|
||||
image=image_paths,
|
||||
upsampling_width=864,
|
||||
upsampling_height=1152,
|
||||
upsampling_strength=0.4,
|
||||
guidance_scale=30,
|
||||
num_inference_steps=30,
|
||||
max_sequence_length=512,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0][0]
|
||||
|
||||
# Save the resulting image
|
||||
image_result.save("visualcloze.png")
|
||||
```
|
||||
|
||||
#### Example for subject-driven generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import VisualClozePipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Load in-context images (make sure the paths are correct and accessible)
|
||||
image_paths = [
|
||||
# in-context examples
|
||||
[
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_reference.jpg'),
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_depth.jpg'),
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_image.jpg'),
|
||||
],
|
||||
[
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_reference.jpg'),
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_depth.jpg'),
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_image.jpg'),
|
||||
],
|
||||
# query with the target image
|
||||
[
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_reference.jpg'),
|
||||
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_depth.jpg'),
|
||||
None, # No image needed for the target image
|
||||
],
|
||||
]
|
||||
|
||||
# Task and content prompt
|
||||
task_prompt = """Each row describes a process that begins with [IMAGE1] an image containing the key object,
|
||||
[IMAGE2] depth map revealing gray-toned spatial layers and results in
|
||||
[IMAGE3] an image with artistic qualitya high-quality image with exceptional detail."""
|
||||
content_prompt = """A vintage porcelain collector's item. Beneath a blossoming cherry tree in early spring,
|
||||
this treasure is photographed up close, with soft pink petals drifting through the air and vibrant blossoms framing the scene."""
|
||||
|
||||
# Run the pipeline
|
||||
image_result = pipe(
|
||||
task_prompt=task_prompt,
|
||||
content_prompt=content_prompt,
|
||||
image=image_paths,
|
||||
upsampling_width=1024,
|
||||
upsampling_height=1024,
|
||||
upsampling_strength=0.2,
|
||||
guidance_scale=30,
|
||||
num_inference_steps=30,
|
||||
max_sequence_length=512,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0][0]
|
||||
|
||||
# Save the resulting image
|
||||
image_result.save("visualcloze.png")
|
||||
```
|
||||
|
||||
#### Utilize each pipeline independently
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
|
||||
pipe = VisualClozeGenerationPipeline.from_pretrained(
|
||||
"VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
image_paths = [
|
||||
# in-context examples
|
||||
[
|
||||
load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
|
||||
),
|
||||
load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
|
||||
),
|
||||
],
|
||||
# query with the target image
|
||||
[
|
||||
load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
|
||||
),
|
||||
None, # No image needed for the target image
|
||||
],
|
||||
]
|
||||
task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
|
||||
content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
|
||||
|
||||
# Stage 1: Generate initial image
|
||||
image = pipe(
|
||||
task_prompt=task_prompt,
|
||||
content_prompt=content_prompt,
|
||||
image=image_paths,
|
||||
guidance_scale=30,
|
||||
num_inference_steps=30,
|
||||
max_sequence_length=512,
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).images[0][0]
|
||||
|
||||
# Stage 2 (optional): Upsample the generated image
|
||||
pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
|
||||
pipe_upsample.to("cuda")
|
||||
|
||||
mask_image = Image.new("RGB", image.size, (255, 255, 255))
|
||||
|
||||
image = pipe_upsample(
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
prompt=content_prompt,
|
||||
width=1344,
|
||||
height=768,
|
||||
strength=0.4,
|
||||
guidance_scale=30,
|
||||
num_inference_steps=30,
|
||||
max_sequence_length=512,
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
image.save("visualcloze.png")
|
||||
```
|
||||
|
||||
## VisualClozePipeline
|
||||
|
||||
[[autodoc]] VisualClozePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## VisualClozeGenerationPipeline
|
||||
|
||||
[[autodoc]] VisualClozeGenerationPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -13,9 +13,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Quantization
|
||||
|
||||
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).
|
||||
|
||||
Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class.
|
||||
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -23,6 +21,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
|
||||
|
||||
</Tip>
|
||||
|
||||
## PipelineQuantizationConfig
|
||||
|
||||
[[autodoc]] quantizers.PipelineQuantizationConfig
|
||||
|
||||
## BitsAndBytesConfig
|
||||
|
||||
|
||||
@@ -78,6 +78,23 @@ For more information and different options about `torch.compile`, refer to the [
|
||||
> [!TIP]
|
||||
> Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.
|
||||
|
||||
### Regional compilation
|
||||
|
||||
Compiling the whole model usually has a big problem space for optimization. Models are often composed of multiple repeated blocks. [Regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) compiles the repeated block first (a transformer encoder block, for example), so that the Torch compiler would re-use its cached/optimized generated code for the other blocks, reducing (often massively) the cold start compilation time observed on the first inference call.
|
||||
|
||||
Enabling regional compilation might require simple yet intrusive changes to the
|
||||
modeling code. However, 🤗 Accelerate provides a utility [`compile_regions()`](https://huggingface.co/docs/accelerate/main/en/usage_guides/compilation#how-to-use-regional-compilation) which automatically compiles
|
||||
the repeated blocks of the provided `nn.Module` sequentially, and the rest of the model separately. This helps with reducing cold start time while keeping most (if not all) of the speedup you would get from full compilation.
|
||||
|
||||
```py
|
||||
# Make sure you're on the latest `accelerate`: `pip install -U accelerate`.
|
||||
from accelerate.utils import compile_regions
|
||||
|
||||
pipe.unet = compile_regions(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
As you may have noticed `compile_regions()` takes the same arguments as `torch.compile()`, allowing flexibility.
|
||||
|
||||
## Benchmark
|
||||
|
||||
We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).
|
||||
|
||||
@@ -48,7 +48,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
|
||||
```py
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
import torch
|
||||
from diffusers import AutoModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
@@ -88,6 +88,8 @@ Setting `device_map="auto"` automatically fills all available space on the GPU(s
|
||||
CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
|
||||
|
||||
```py
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer_8bit,
|
||||
@@ -132,7 +134,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
|
||||
```py
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
import torch
|
||||
from diffusers import AutoModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
@@ -171,6 +173,8 @@ Let's generate an image using our quantized models.
|
||||
Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
|
||||
|
||||
```py
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer_4bit,
|
||||
@@ -214,6 +218,8 @@ Check your memory footprint with the `get_memory_footprint` method:
|
||||
print(model.get_memory_footprint())
|
||||
```
|
||||
|
||||
Note that this only tells you the memory footprint of the model params and does _not_ estimate the inference memory requirements.
|
||||
|
||||
Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters:
|
||||
|
||||
```py
|
||||
@@ -413,4 +419,4 @@ transformer_4bit.dequantize()
|
||||
## Resources
|
||||
|
||||
* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
|
||||
* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
|
||||
* [Training](https://github.com/huggingface/diffusers/blob/8c661ea586bf11cb2440da740dd3c4cf84679b85/examples/dreambooth/README_hidream.md#using-quantization)
|
||||
@@ -39,3 +39,90 @@ Diffusers currently supports the following quantization methods.
|
||||
- [Quanto](./quanto.md)
|
||||
|
||||
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
|
||||
## Pipeline-level quantization
|
||||
|
||||
Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply
|
||||
quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can
|
||||
do this with [`~quantizers.PipelineQuantizationConfig`].
|
||||
|
||||
Start by defining a `PipelineQuantizationConfig`:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.quantizers.quantization_config import QuantoConfig
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={
|
||||
"transformer": QuantoConfig(weights_dtype="int8"),
|
||||
"text_encoder_2": BitsAndBytesConfig(
|
||||
load_in_4bit=True, compute_dtype=torch.bfloat16
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe("photo of a cute dog").images[0]
|
||||
```
|
||||
|
||||
This method allows for more granular control over the quantization specifications of individual
|
||||
model-level components of a pipeline. It also allows for different quantization backends for
|
||||
different components. In the above example, you used a combination of Quanto and BitsandBytes. However,
|
||||
one caveat of this method is that users need to know which components come from `transformers` to be able
|
||||
to import the right quantization config class.
|
||||
|
||||
The other method is simpler in terms of experience but is
|
||||
less-flexible. Start by defining a `PipelineQuantizationConfig` but in a different way:
|
||||
|
||||
```py
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_4bit",
|
||||
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
```
|
||||
|
||||
This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example.
|
||||
|
||||
In this case, `quant_kwargs` will be used to initialize the quantization specifications
|
||||
of the respective quantization configuration class of `quant_backend`. `components_to_quantize`
|
||||
is used to denote the components that will be quantized. For most pipelines, you would want to
|
||||
keep `transformer` in the list as that is often the most compute and memory intensive.
|
||||
|
||||
The config below will work for most diffusion pipelines that have a `transformer` component present.
|
||||
In most case, you will want to quantize the `transformer` component as that is often the most compute-
|
||||
intensive part of a diffusion pipeline.
|
||||
|
||||
```py
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_4bit",
|
||||
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
|
||||
components_to_quantize=["transformer"],
|
||||
)
|
||||
```
|
||||
|
||||
Below is a list of the supported quantization backends available in both `diffusers` and `transformers`:
|
||||
|
||||
* `bitsandbytes_4bit`
|
||||
* `bitsandbytes_8bit`
|
||||
* `gguf`
|
||||
* `quanto`
|
||||
* `torchao`
|
||||
|
||||
|
||||
Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's
|
||||
recommended to quantize the text encoders that are memory-intensive. Some examples include T5,
|
||||
Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through
|
||||
`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`).
|
||||
@@ -430,6 +430,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1554,6 +1557,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
@@ -1562,6 +1566,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -658,6 +658,8 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
action="store_true",
|
||||
@@ -1248,6 +1250,7 @@ def main(args):
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
use_dora=args.use_dora,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
@@ -1260,6 +1263,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
use_dora=args.use_dora,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
|
||||
@@ -767,6 +767,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
action="store_true",
|
||||
@@ -1558,6 +1561,7 @@ def main(args):
|
||||
r=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
@@ -1570,6 +1574,7 @@ def main(args):
|
||||
r=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -1114,17 +1114,22 @@ def main(args):
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1156,8 +1161,14 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -524,6 +524,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
@@ -932,6 +935,7 @@ def main(args):
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
|
||||
)
|
||||
@@ -942,6 +946,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -358,6 +358,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1236,6 +1239,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
@@ -1244,6 +1248,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -417,6 +417,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1161,6 +1164,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
|
||||
@@ -328,6 +328,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1023,6 +1026,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
|
||||
@@ -323,6 +323,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1021,6 +1024,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
|
||||
@@ -367,6 +367,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1264,6 +1267,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
@@ -1273,6 +1277,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -659,6 +659,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
action="store_true",
|
||||
@@ -1199,10 +1202,11 @@ def main(args):
|
||||
text_encoder_one.gradient_checkpointing_enable()
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
def get_lora_config(rank, use_dora, target_modules):
|
||||
def get_lora_config(rank, dropout, use_dora, target_modules):
|
||||
base_config = {
|
||||
"r": rank,
|
||||
"lora_alpha": rank,
|
||||
"lora_dropout": dropout,
|
||||
"init_lora_weights": "gaussian",
|
||||
"target_modules": target_modules,
|
||||
}
|
||||
@@ -1218,14 +1222,24 @@ def main(args):
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
|
||||
unet_lora_config = get_lora_config(
|
||||
rank=args.rank,
|
||||
dropout=args.lora_dropout,
|
||||
use_dora=args.use_dora,
|
||||
target_modules=unet_target_modules,
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
|
||||
text_lora_config = get_lora_config(
|
||||
rank=args.rank,
|
||||
dropout=args.lora_dropout,
|
||||
use_dora=args.use_dora,
|
||||
target_modules=text_target_modules,
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
|
||||
@@ -812,7 +812,7 @@ def main(args):
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rat * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
# Training SANA Sprint Diffuser
|
||||
|
||||
This README explains how to use the provided bash script commands to download a pre-trained teacher diffuser model and train it on a specific dataset, following the [SANA Sprint methodology](https://arxiv.org/abs/2503.09641).
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Define the local paths
|
||||
|
||||
Set a variable for your desired output directory. This directory will store the downloaded model and the training checkpoints/results.
|
||||
|
||||
```bash
|
||||
your_local_path='output' # Or any other path you prefer
|
||||
mkdir -p $your_local_path # Create the directory if it doesn't exist
|
||||
```
|
||||
|
||||
### 2. Download the pre-trained model
|
||||
|
||||
Download the SANA Sprint teacher model from Hugging Face Hub. The script uses the 1.6B parameter model.
|
||||
|
||||
```bash
|
||||
huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
|
||||
```
|
||||
|
||||
*(Optional: You can also download the 0.6B model by replacing the model name: `Efficient-Large-Model/Sana_Sprint_0.6B_1024px_teacher_diffusers`)*
|
||||
|
||||
### 3. Acquire the dataset shards
|
||||
|
||||
The training script in this example uses specific `.parquet` shards from a randomly selected `brivangl/midjourney-v6-llava` dataset instead of downloading the entire dataset automatically via `dataset_name`.
|
||||
|
||||
The script specifically uses these three files:
|
||||
* `data/train_000.parquet`
|
||||
* `data/train_001.parquet`
|
||||
* `data/train_002.parquet`
|
||||
|
||||
|
||||
|
||||
You can either:
|
||||
|
||||
Let the script download the dataset automatically during first run
|
||||
|
||||
Or download it manually
|
||||
|
||||
**Note:** The full `brivangl/midjourney-v6-llava` dataset is much larger and contains many more shards. This script example explicitly trains *only* on the three specified shards.
|
||||
|
||||
## Usage
|
||||
|
||||
Once the model is downloaded, you can run the training script.
|
||||
|
||||
```bash
|
||||
|
||||
your_local_path='output' # Ensure this variable is set
|
||||
|
||||
python train_sana_sprint_diffusers.py \
|
||||
--pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
|
||||
--output_dir=$your_local_path \
|
||||
--mixed_precision=bf16 \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-6 \
|
||||
--max_train_steps=30000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--dataset_name='brivangl/midjourney-v6-llava' \
|
||||
--file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
|
||||
--checkpointing_steps=500 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--seed=453645634 \
|
||||
--train_largest_timestep \
|
||||
--misaligned_pairs_D \
|
||||
--gradient_checkpointing \
|
||||
--resume_from_checkpoint="latest" \
|
||||
```
|
||||
|
||||
### Explanation of parameters
|
||||
|
||||
* `--pretrained_model_name_or_path`: Path to the downloaded pre-trained model directory.
|
||||
* `--output_dir`: Directory where training logs, checkpoints, and the final model will be saved.
|
||||
* `--mixed_precision`: Use BF16 mixed precision for training, which can save memory and speed up training on compatible hardware.
|
||||
* `--resolution`: The image resolution used for training (1024x1024).
|
||||
* `--learning_rate`: The learning rate for the optimizer.
|
||||
* `--max_train_steps`: The total number of training steps to perform.
|
||||
* `--dataloader_num_workers`: Number of worker processes for loading data. Increase for faster data loading if your CPU and disk can handle it.
|
||||
* `--dataset_name`: The name of the dataset on Hugging Face Hub (`brivangl/midjourney-v6-llava`).
|
||||
* `--file_path`: **Specifies the local paths to the dataset shards to be used for training.** In this case, `data/train_000.parquet`, `data/train_001.parquet`, and `data/train_002.parquet`.
|
||||
* `--checkpointing_steps`: Save a training checkpoint every X steps.
|
||||
* `--checkpoints_total_limit`: Maximum number of checkpoints to keep. Older checkpoints will be deleted.
|
||||
* `--train_batch_size`: The batch size per GPU.
|
||||
* `--gradient_accumulation_steps`: Number of steps to accumulate gradients before performing an optimizer step.
|
||||
* `--seed`: Random seed for reproducibility.
|
||||
* `--train_largest_timestep`: A specific training strategy focusing on larger timesteps.
|
||||
* `--misaligned_pairs_D`: Another specific training strategy to add misaligned image-text pairs as fake data for GAN.
|
||||
* `--gradient_checkpointing`: Enable gradient checkpointing to save GPU memory.
|
||||
* `--resume_from_checkpoint`: Allows resuming training from the latest saved checkpoint in the `--output_dir`.
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,26 @@
|
||||
your_local_path='output'
|
||||
|
||||
huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
|
||||
|
||||
# or Sana_Sprint_0.6B_1024px_teacher_diffusers
|
||||
|
||||
python train_sana_sprint_diffusers.py \
|
||||
--pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
|
||||
--output_dir=$your_local_path \
|
||||
--mixed_precision=bf16 \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-6 \
|
||||
--max_train_steps=30000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--dataset_name='brivangl/midjourney-v6-llava' \
|
||||
--file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
|
||||
--checkpointing_steps=500 --checkpoints_total_limit=10 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--seed=453645634 \
|
||||
--train_largest_timestep \
|
||||
--misaligned_pairs_D \
|
||||
--gradient_checkpointing \
|
||||
--resume_from_checkpoint="latest" \
|
||||
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
state_dict.pop(key)
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
|
||||
block_index = int(key.split(".")[1].removeprefix("block"))
|
||||
new_key = key
|
||||
|
||||
old_prefix = f"blocks.block{block_index}"
|
||||
new_prefix = f"transformer_blocks.{block_index}"
|
||||
new_key = new_prefix + new_key.removeprefix(old_prefix)
|
||||
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"affline_norm": "time_embed.norm",
|
||||
".blocks.0.block.attn": ".attn1",
|
||||
".blocks.1.block.attn": ".attn2",
|
||||
".blocks.2.block": ".ff",
|
||||
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
|
||||
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
|
||||
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
|
||||
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
|
||||
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
|
||||
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
|
||||
"to_q.0": "to_q",
|
||||
"to_q.1": "norm_q",
|
||||
"to_k.0": "to_k",
|
||||
"to_k.1": "norm_k",
|
||||
"to_v.0": "to_v",
|
||||
"layer1": "net.0.proj",
|
||||
"layer2": "net.2",
|
||||
"proj.1": "proj",
|
||||
"x_embedder": "patch_embed",
|
||||
"extra_pos_embedder": "learnable_pos_embed",
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"blocks.block": rename_transformer_blocks_,
|
||||
"logvar.0.freqs": remove_keys_,
|
||||
"logvar.0.phases": remove_keys_,
|
||||
"logvar.1.weight": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
}
|
||||
|
||||
TRANSFORMER_CONFIGS = {
|
||||
"Cosmos-1.0-Diffusion-7B-Text2World": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 1.0, 1.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
"Cosmos-1.0-Diffusion-7B-Video2World": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 1.0, 1.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
"Cosmos-1.0-Diffusion-14B-Text2World": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 36,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 2.0, 2.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
"Cosmos-1.0-Diffusion-14B-Video2World": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 36,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (2.0, 2.0, 2.0),
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": "learnable",
|
||||
},
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
"down.0": "down_blocks.0",
|
||||
"down.1": "down_blocks.1",
|
||||
"down.2": "down_blocks.2",
|
||||
"up.0": "up_blocks.2",
|
||||
"up.1": "up_blocks.1",
|
||||
"up.2": "up_blocks.0",
|
||||
".block.": ".resnets.",
|
||||
"downsample": "downsamplers.0",
|
||||
"upsample": "upsamplers.0",
|
||||
"mid.block_1": "mid_block.resnets.0",
|
||||
"mid.attn_1.0": "mid_block.attentions.0",
|
||||
"mid.attn_1.1": "mid_block.temp_attentions.0",
|
||||
"mid.block_2": "mid_block.resnets.1",
|
||||
".q.conv3d": ".to_q",
|
||||
".k.conv3d": ".to_k",
|
||||
".v.conv3d": ".to_v",
|
||||
".proj_out.conv3d": ".to_out.0",
|
||||
".0.conv3d": ".conv_s",
|
||||
".1.conv3d": ".conv_t",
|
||||
"conv1.conv3d": "conv1",
|
||||
"conv2.conv3d": "conv2",
|
||||
"conv3.conv3d": "conv3",
|
||||
"nin_shortcut.conv3d": "conv_shortcut",
|
||||
"quant_conv.conv3d": "quant_conv",
|
||||
"post_quant_conv.conv3d": "post_quant_conv",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"wavelets": remove_keys_,
|
||||
"_arange": remove_keys_,
|
||||
"patch_size_buffer": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_CONFIGS = {
|
||||
"CV8x8x8-0.1": {
|
||||
"name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 16,
|
||||
"encoder_block_out_channels": (128, 256, 512, 512),
|
||||
"decode_block_out_channels": (256, 512, 512, 512),
|
||||
"attention_resolutions": (32,),
|
||||
"resolution": 1024,
|
||||
"num_layers": 2,
|
||||
"patch_size": 4,
|
||||
"patch_type": "haar",
|
||||
"scaling_factor": 1.0,
|
||||
"spatial_compression_ratio": 8,
|
||||
"temporal_compression_ratio": 8,
|
||||
"latents_mean": None,
|
||||
"latents_std": None,
|
||||
},
|
||||
},
|
||||
"CV8x8x8-1.0": {
|
||||
"name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 16,
|
||||
"encoder_block_out_channels": (128, 256, 512, 512),
|
||||
"decode_block_out_channels": (256, 512, 512, 512),
|
||||
"attention_resolutions": (32,),
|
||||
"resolution": 1024,
|
||||
"num_layers": 2,
|
||||
"patch_size": 4,
|
||||
"patch_type": "haar",
|
||||
"scaling_factor": 1.0,
|
||||
"spatial_compression_ratio": 8,
|
||||
"temporal_compression_ratio": 8,
|
||||
"latents_mean": None,
|
||||
"latents_std": None,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
if "model" in saved_dict.keys():
|
||||
state_dict = state_dict["model"]
|
||||
if "module" in saved_dict.keys():
|
||||
state_dict = state_dict["module"]
|
||||
if "state_dict" in saved_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_transformer(transformer_type: str, ckpt_path: str):
|
||||
PREFIX_KEY = "net."
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
||||
|
||||
with init_empty_weights():
|
||||
config = TRANSFORMER_CONFIGS[transformer_type]
|
||||
transformer = CosmosTransformer3DModel(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(vae_type: str):
|
||||
model_name = VAE_CONFIGS[vae_type]["name"]
|
||||
snapshot_directory = snapshot_download(model_name, repo_type="model")
|
||||
directory = pathlib.Path(snapshot_directory)
|
||||
|
||||
autoencoder_file = directory / "autoencoder.jit"
|
||||
mean_std_file = directory / "mean_std.pt"
|
||||
|
||||
original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
|
||||
if mean_std_file.exists():
|
||||
mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
mean_std = (None, None)
|
||||
|
||||
config = VAE_CONFIGS[vae_type]["diffusers_config"]
|
||||
config.update(
|
||||
{
|
||||
"latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
|
||||
"latents_std": mean_std[1].detach().cpu().numpy().tolist(),
|
||||
}
|
||||
)
|
||||
vae = AutoencoderKLCosmos(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
|
||||
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None
|
||||
assert args.vae_type is not None
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.vae_type is not None:
|
||||
vae = convert_vae(args.vae_type)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
|
||||
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
|
||||
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
|
||||
# So, the sigma_min values that is used is the default value of 0.002.
|
||||
scheduler = EDMEulerScheduler(
|
||||
sigma_min=0.002,
|
||||
sigma_max=80,
|
||||
sigma_data=0.5,
|
||||
sigma_schedule="karras",
|
||||
num_train_timesteps=1000,
|
||||
prediction_type="epsilon",
|
||||
rho=7.0,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
|
||||
pipe = CosmosTextToWorldPipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
@@ -7,7 +7,15 @@ from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
|
||||
from diffusers import (
|
||||
AutoencoderKLLTXVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTXConditionPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
LTXVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
@@ -123,17 +131,10 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
version: str = "0.9.0",
|
||||
):
|
||||
def convert_transformer(ckpt_path: str, config, dtype: torch.dtype):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
config = {}
|
||||
if version == "0.9.5":
|
||||
config["_use_causal_rope_fix"] = True
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel(**config)
|
||||
|
||||
@@ -180,8 +181,59 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
return vae
|
||||
|
||||
|
||||
def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype):
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
|
||||
with init_empty_weights():
|
||||
latent_upsampler = LTXLatentUpsamplerModel(**config)
|
||||
|
||||
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
latent_upsampler.to(dtype)
|
||||
return latent_upsampler
|
||||
|
||||
|
||||
def get_transformer_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.7":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 4096,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 64,
|
||||
"cross_attention_dim": 2048,
|
||||
"num_layers": 28,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 4096,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.0":
|
||||
if version in ["0.9.0"]:
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -210,7 +262,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"decoder_causal": False,
|
||||
"timestep_conditioning": False,
|
||||
}
|
||||
elif version == "0.9.1":
|
||||
elif version in ["0.9.1"]:
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -240,7 +292,39 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
elif version == "0.9.5":
|
||||
elif version in ["0.9.5"]:
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
"LTXVideo095DownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
||||
elif version in ["0.9.7"]:
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -275,12 +359,33 @@ def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
return config
|
||||
|
||||
|
||||
def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.7":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 512,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||
parser.add_argument(
|
||||
"--spatial_latent_upsampler_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to original spatial latent upsampler checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
||||
)
|
||||
@@ -294,7 +399,11 @@ def get_args():
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
|
||||
"--version",
|
||||
type=str,
|
||||
default="0.9.0",
|
||||
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
|
||||
help="Version of the LTX model",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -320,11 +429,9 @@ if __name__ == "__main__":
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
output_path = Path(args.output_path)
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
|
||||
config = get_transformer_config(args.version)
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
@@ -336,6 +443,16 @@ if __name__ == "__main__":
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
if args.spatial_latent_upsampler_path is not None:
|
||||
config = get_spatial_latent_upsampler_config(args.version)
|
||||
latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler(
|
||||
args.spatial_latent_upsampler_path, config, dtype
|
||||
)
|
||||
if not args.save_pipeline:
|
||||
latent_upsampler.save_pretrained(
|
||||
output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
@@ -348,7 +465,7 @@ if __name__ == "__main__":
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
if args.version == "0.9.5":
|
||||
if args.version in ["0.9.5", "0.9.7"]:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
@@ -360,12 +477,40 @@ if __name__ == "__main__":
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
|
||||
if args.version in ["0.9.0", "0.9.1", "0.9.5"]:
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
)
|
||||
pipe.save_pretrained(
|
||||
output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB"
|
||||
)
|
||||
elif args.version in ["0.9.7"]:
|
||||
pipe = LTXConditionPipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
)
|
||||
pipe_upsample = LTXLatentUpsamplePipeline(
|
||||
vae=vae,
|
||||
latent_upsampler=latent_upsampler,
|
||||
)
|
||||
pipe.save_pretrained(
|
||||
(output_path / "ltx_pipeline").as_posix(),
|
||||
safe_serialization=True,
|
||||
variant=variant,
|
||||
max_shard_size="5GB",
|
||||
)
|
||||
pipe_upsample.save_pretrained(
|
||||
(output_path / "ltx_upsample_pipeline").as_posix(),
|
||||
safe_serialization=True,
|
||||
variant=variant,
|
||||
max_shard_size="5GB",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {args.version}")
|
||||
|
||||
@@ -148,6 +148,7 @@ else:
|
||||
"AutoencoderKL",
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLCosmos",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
@@ -158,6 +159,7 @@ else:
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"CacheMixin",
|
||||
"ChromaTransformer2DModel",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"CogView4Transformer2DModel",
|
||||
@@ -166,6 +168,7 @@ else:
|
||||
"ControlNetModel",
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
"CosmosTransformer3DModel",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
"FluxControlNetModel",
|
||||
@@ -175,6 +178,7 @@ else:
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"HunyuanVideoFramepackTransformer3DModel",
|
||||
"HunyuanVideoTransformer3DModel",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
@@ -356,6 +360,9 @@ else:
|
||||
"CogView3PlusPipeline",
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
"CosmosVideoToWorldPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
@@ -376,6 +383,7 @@ else:
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"HunyuanSkyreelsImageToVideoPipeline",
|
||||
"HunyuanVideoFramepackPipeline",
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
"HunyuanVideoPipeline",
|
||||
"I2VGenXLPipeline",
|
||||
@@ -413,6 +421,7 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTXConditionPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
@@ -513,6 +522,8 @@ else:
|
||||
"VersatileDiffusionPipeline",
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
"VideoToVideoSDPipeline",
|
||||
"VisualClozeGenerationPipeline",
|
||||
"VisualClozePipeline",
|
||||
"VQDiffusionPipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
@@ -743,6 +754,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -761,6 +773,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxControlNetModel,
|
||||
@@ -770,6 +783,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
@@ -930,6 +944,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusPipeline,
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
CosmosVideoToWorldPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
@@ -950,6 +967,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoFramepackPipeline,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
HunyuanVideoPipeline,
|
||||
I2VGenXLPipeline,
|
||||
@@ -987,6 +1005,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTXConditionPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
@@ -1086,6 +1105,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
VisualClozeGenerationPipeline,
|
||||
VisualClozePipeline,
|
||||
VQDiffusionPipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
|
||||
@@ -348,7 +348,7 @@ def _load_lora_into_text_encoder(
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
@@ -374,7 +374,7 @@ def _load_lora_into_text_encoder(
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
|
||||
|
||||
@@ -727,8 +727,25 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
elif k.startswith("lora_te1_"):
|
||||
has_te_keys = True
|
||||
continue
|
||||
elif k.startswith("lora_transformer_context_embedder"):
|
||||
diffusers_key = "context_embedder"
|
||||
elif k.startswith("lora_transformer_norm_out_linear"):
|
||||
diffusers_key = "norm_out.linear"
|
||||
elif k.startswith("lora_transformer_proj_out"):
|
||||
diffusers_key = "proj_out"
|
||||
elif k.startswith("lora_transformer_x_embedder"):
|
||||
diffusers_key = "x_embedder"
|
||||
elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
|
||||
i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
|
||||
diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
|
||||
elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
|
||||
i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
|
||||
diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
|
||||
elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
|
||||
i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
|
||||
diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(f"Handling for key ({k}) is not implemented.")
|
||||
|
||||
if "attn_" in k:
|
||||
if "_to_out_0" in k:
|
||||
@@ -1687,3 +1704,11 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
|
||||
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
|
||||
raise ValueError("Invalid LoRA state dict for HiDream.")
|
||||
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_musubi_wan_lora_to_diffusers,
|
||||
_convert_non_diffusers_hidream_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_non_diffusers_wan_lora_to_diffusers,
|
||||
@@ -2103,7 +2104,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
prefix = prefix or cls.transformer_name
|
||||
for key in list(state_dict.keys()):
|
||||
if key.split(".")[0] == prefix:
|
||||
state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
|
||||
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
|
||||
|
||||
# Find invalid keys
|
||||
transformer_state_dict = transformer.state_dict()
|
||||
@@ -2425,7 +2426,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
prefix = prefix or cls.transformer_name
|
||||
for key in list(state_dict.keys()):
|
||||
if key.split(".")[0] == prefix:
|
||||
state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
|
||||
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
|
||||
|
||||
# Expand transformer parameter shapes if they don't match lora
|
||||
has_param_with_shape_update = False
|
||||
@@ -5371,7 +5372,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
@@ -5465,6 +5465,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
|
||||
if is_non_diffusers_format:
|
||||
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
|
||||
@@ -57,6 +57,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
@@ -230,7 +231,7 @@ class PeftAdapterMixin:
|
||||
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
|
||||
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
|
||||
@@ -261,7 +262,9 @@ class PeftAdapterMixin:
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
network_alphas = {
|
||||
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
|
||||
|
||||
@@ -29,8 +29,10 @@ from .single_file_utils import (
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_chroma_transformer_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hidream_transformer_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
@@ -133,6 +135,14 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"HiDreamImageTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"ChromaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_chroma_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -187,9 +197,8 @@ class FromOriginalModelMixin:
|
||||
original_config (`str`, *optional*):
|
||||
Dict or path to a yaml file containing the configuration for the model in its original format.
|
||||
If a dict is provided, it will be used to initialize the model configuration.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
torch_dtype (`torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
|
||||
],
|
||||
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
||||
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
||||
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
|
||||
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
||||
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
||||
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
|
||||
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
|
||||
model_type = "wan-t2v-14B"
|
||||
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
|
||||
model_type = "hidream"
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -2195,7 +2199,6 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
# norms.
|
||||
## norm1
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mod.lin.weight"
|
||||
)
|
||||
@@ -2281,6 +2284,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
# single transformer blocks
|
||||
for i in range(num_single_layers):
|
||||
block_prefix = f"single_transformer_blocks.{i}."
|
||||
|
||||
# norm.linear <- single_blocks.0.modulation.lin
|
||||
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.modulation.lin.weight"
|
||||
@@ -2316,6 +2320,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
||||
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
|
||||
)
|
||||
@@ -3293,3 +3298,160 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict[key] = value
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def convert_chroma_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
for k in keys:
|
||||
if k.startswith("distilled_guidance_layer.norms"):
|
||||
converted_state_dict[k.replace(".scale", ".weight")] = checkpoint.pop(k)
|
||||
elif k.startswith("distilled_guidance_layer.layer"):
|
||||
converted_state_dict[k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")] = checkpoint.pop(
|
||||
k
|
||||
)
|
||||
elif k.startswith("distilled_guidance_layer"):
|
||||
converted_state_dict[k] = checkpoint.pop(k)
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
||||
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
||||
mlp_ratio = 4.0
|
||||
inner_dim = 3072
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
# context_embedder
|
||||
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
|
||||
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
|
||||
|
||||
# x_embedder
|
||||
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
|
||||
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
# norms.
|
||||
|
||||
# Q, K, V
|
||||
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
|
||||
)
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
|
||||
# qk_norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
|
||||
)
|
||||
# ff img_mlp
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mlp.0.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.bias"
|
||||
)
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.bias"
|
||||
)
|
||||
|
||||
# single transformer blocks
|
||||
for i in range(num_single_layers):
|
||||
block_prefix = f"single_transformer_blocks.{i}."
|
||||
|
||||
# Q, K, V, mlp
|
||||
mlp_hidden_dim = int(inner_dim * mlp_ratio)
|
||||
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
|
||||
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(
|
||||
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
|
||||
# qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.norm.key_norm.scale"
|
||||
)
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
|
||||
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
|
||||
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -32,6 +32,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
@@ -73,12 +74,15 @@ if is_torch_available():
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
@@ -113,6 +117,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -146,16 +151,19 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .transformers import (
|
||||
AllegroTransformer3DModel,
|
||||
AuraFlowTransformer2DModel,
|
||||
ChromaTransformer2DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
ConsisIDTransformer3DModel,
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
|
||||
@@ -161,9 +161,8 @@ class MultiAdapter(ModelMixin):
|
||||
pretrained_model_path (`os.PathLike`):
|
||||
A path to a *directory* containing model weights saved using
|
||||
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
torch_dtype (`torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
|
||||
@@ -203,8 +203,8 @@ class Attention(nn.Module):
|
||||
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
|
||||
elif qk_norm == "rms_norm":
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
elif qk_norm == "rms_norm_across_heads":
|
||||
# LTX applies qk norm across all heads
|
||||
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
@@ -52,9 +52,8 @@ class AutoModel(ConfigMixin):
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
torch_dtype (`torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
@@ -3,6 +3,7 @@ from .autoencoder_dc import AutoencoderDC
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -744,6 +744,17 @@ class DiagonalGaussianDistribution(object):
|
||||
return self.mean
|
||||
|
||||
|
||||
class IdentityDistribution(object):
|
||||
def __init__(self, parameters: torch.Tensor):
|
||||
self.parameters = parameters
|
||||
|
||||
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
|
||||
return self.parameters
|
||||
|
||||
def mode(self) -> torch.Tensor:
|
||||
return self.parameters
|
||||
|
||||
|
||||
class EncoderTiny(nn.Module):
|
||||
r"""
|
||||
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
|
||||
|
||||
@@ -130,9 +130,8 @@ class MultiControlNetModel(ModelMixin):
|
||||
A path to a *directory* containing model weights saved using
|
||||
[`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
|
||||
`./my_model_directory/controlnet`.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
torch_dtype (`torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
|
||||
@@ -143,9 +143,8 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
A path to a *directory* containing model weights saved using
|
||||
[`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
|
||||
`./my_model_directory/controlnet`.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
torch_dtype (`torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
|
||||
@@ -31,7 +31,7 @@ def get_timestep_embedding(
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
@@ -1204,7 +1204,7 @@ def apply_rotary_emb(
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen and CogView4
|
||||
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
@@ -1327,7 +1327,7 @@ class Timesteps(nn.Module):
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
|
||||
@@ -787,9 +787,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
torch_dtype (`torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
@@ -171,6 +171,46 @@ class AdaLayerNormZero(nn.Module):
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class AdaLayerNormZeroPruned(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
|
||||
super().__init__()
|
||||
if num_embeddings is not None:
|
||||
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.emb is not None:
|
||||
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class AdaLayerNormZeroSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
|
||||
@@ -17,12 +17,15 @@ if is_torch_available():
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_chroma import ChromaTransformer2DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_cosmos import CosmosTransformer3DModel
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
|
||||
@@ -0,0 +1,753 @@
|
||||
# Copyright 2025 Black Forest Labs, The HuggingFace Team and lodestone-rock. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0_NPU,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
)
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepLabelEmbeddings,
|
||||
FluxPosEmbed,
|
||||
PixArtAlphaTextProjection,
|
||||
Timesteps,
|
||||
get_timestep_embedding,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChromaApproximator(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
|
||||
super().__init__()
|
||||
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.layers = nn.ModuleList(
|
||||
[PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
|
||||
)
|
||||
self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
|
||||
self.out_proj = nn.Linear(hidden_dim, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.in_proj(x)
|
||||
|
||||
for layer, norms in zip(self.layers, self.norms):
|
||||
x = x + layer(norms(x))
|
||||
|
||||
return self.out_proj(x)
|
||||
|
||||
|
||||
class ChromaTimestepEmbeddings(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
out_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
|
||||
self.register_buffer(
|
||||
"mod_proj",
|
||||
get_timestep_embedding(
|
||||
torch.arange(out_dim) * 1000,
|
||||
2 * num_channels,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor) -> torch.Tensor:
|
||||
mod_index_length = self.mod_proj.shape[0]
|
||||
|
||||
timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
|
||||
guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
|
||||
|
||||
mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
|
||||
timestep_guidance = (
|
||||
torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
|
||||
)
|
||||
input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
|
||||
|
||||
return input_vec
|
||||
|
||||
|
||||
class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
|
||||
super().__init__()
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa
|
||||
|
||||
|
||||
class ChromaAdaLayerNormZeroPruned(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
|
||||
super().__init__()
|
||||
if num_embeddings is not None:
|
||||
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.emb is not None:
|
||||
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class ChromaSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
|
||||
|
||||
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
if is_torch_npu_available():
|
||||
deprecation_message = (
|
||||
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
|
||||
"should be set explicitly using the `set_attn_processor` method."
|
||||
)
|
||||
deprecate("npu_processor", "0.34.0", deprecation_message)
|
||||
processor = FluxAttnProcessor2_0_NPU()
|
||||
else:
|
||||
processor = FluxAttnProcessor2_0()
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm="rms_norm",
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
gate = gate.unsqueeze(1)
|
||||
hidden_states = gate * self.proj_out(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class ChromaTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
qk_norm: str = "rms_norm",
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
|
||||
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=FluxAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
|
||||
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb_txt
|
||||
)
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
# Attention.
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
if len(attention_outputs) == 2:
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
elif len(attention_outputs) == 3:
|
||||
attn_output, context_attn_output, ip_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
if len(attention_outputs) == 3:
|
||||
hidden_states = hidden_states + ip_attn_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
|
||||
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class ChromaAdaLayerNormContinuous(nn.Module):
|
||||
r"""
|
||||
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
|
||||
|
||||
Args:
|
||||
embedding_dim (`int`): Embedding dimension to use during projection.
|
||||
conditioning_embedding_dim (`int`): Dimension of the input condition.
|
||||
elementwise_affine (`bool`, defaults to `True`):
|
||||
Boolean flag to denote if affine transformation should be applied.
|
||||
eps (`float`, defaults to 1e-5): Epsilon factor.
|
||||
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
|
||||
norm_type (`str`, defaults to `"layer_norm"`):
|
||||
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
||||
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
||||
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
||||
# However, this is how it was implemented in the original code, and it's rather likely you should
|
||||
# set `elementwise_affine` to False.
|
||||
elementwise_affine=True,
|
||||
eps=1e-5,
|
||||
bias=True,
|
||||
norm_type="layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = nn.RMSNorm(embedding_dim, eps, elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
||||
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
||||
shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class ChromaTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model based on Flux SCHNELL architecture.
|
||||
|
||||
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `1`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `64`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `19`):
|
||||
The number of layers of dual stream DiT blocks to use.
|
||||
num_single_layers (`int`, defaults to `38`):
|
||||
The number of layers of single stream DiT blocks to use.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of dimensions to use for each attention head.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of attention heads to use.
|
||||
joint_attention_dim (`int`, defaults to `4096`):
|
||||
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
||||
`encoder_hidden_states`).
|
||||
pooled_projection_dim (`int`, defaults to `768`):
|
||||
The number of dimensions to use for the pooled projection.
|
||||
guidance_embeds (`bool`, defaults to `False`):
|
||||
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
||||
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions to use for the rotary positional embeddings.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
|
||||
approximator_in_factor: int = 16,
|
||||
approximator_hidden_dim: int = 5120,
|
||||
approximator_layers: int = 5,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
||||
|
||||
self.time_text_embed = ChromaTimestepEmbeddings(
|
||||
num_channels=approximator_in_factor, out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2
|
||||
)
|
||||
self.distilled_guidance_layer = ChromaApproximator(
|
||||
in_dim=in_channels,
|
||||
out_dim=self.inner_dim,
|
||||
hidden_dim=approximator_hidden_dim,
|
||||
n_layers=approximator_layers,
|
||||
)
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
ChromaTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
ChromaSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = ChromaAdaLayerNormContinuous(
|
||||
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
|
||||
)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
return_dict: bool = True,
|
||||
controlnet_blocks_repeat: bool = False,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
input_vec = self.time_text_embed(timestep)
|
||||
pooled_temb = self.distilled_guidance_layer(input_vec)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
||||
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
img_offset = 3 * len(self.single_transformer_blocks)
|
||||
txt_offset = img_offset + 6 * len(self.transformer_blocks)
|
||||
img_modulation = img_offset + 6 * index_block
|
||||
text_modulation = txt_offset + 6 * index_block
|
||||
temb = torch.cat(
|
||||
(
|
||||
pooled_temb[:, img_modulation : img_modulation + 6],
|
||||
pooled_temb[:, text_modulation : text_modulation + 6],
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
# For Xlabs ControlNet.
|
||||
if controlnet_blocks_repeat:
|
||||
hidden_states = (
|
||||
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
start_idx = 3 * index_block
|
||||
temb = pooled_temb[:, start_idx : start_idx + 3]
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
temb = pooled_temb[:, -2:]
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -0,0 +1,555 @@
|
||||
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torchvision_available
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class CosmosPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Linear(in_channels * patch_size[0] * patch_size[1] * patch_size[2], out_channels, bias=bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, num_channels, num_frames // p_t, p_t, height // p_h, p_h, width // p_w, p_w
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosTimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int) -> None:
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_features, out_features, bias=False)
|
||||
self.activation = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear_1(timesteps)
|
||||
emb = self.activation(emb)
|
||||
emb = self.linear_2(emb)
|
||||
return emb
|
||||
|
||||
|
||||
class CosmosEmbedding(nn.Module):
|
||||
def __init__(self, embedding_dim: int, condition_dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
|
||||
self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim)
|
||||
self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
|
||||
temb = self.t_embedder(timesteps_proj)
|
||||
embedded_timestep = self.norm(timesteps_proj)
|
||||
return temb, embedded_timestep
|
||||
|
||||
|
||||
class CosmosAdaLayerNorm(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_features: int) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = in_features
|
||||
|
||||
self.activation = nn.SiLU()
|
||||
self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
|
||||
self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
|
||||
self.linear_2 = nn.Linear(hidden_features, 2 * in_features, bias=False)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, embedded_timestep: torch.Tensor, temb: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
embedded_timestep = self.activation(embedded_timestep)
|
||||
embedded_timestep = self.linear_1(embedded_timestep)
|
||||
embedded_timestep = self.linear_2(embedded_timestep)
|
||||
|
||||
if temb is not None:
|
||||
embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
|
||||
|
||||
shift, scale = embedded_timestep.chunk(2, dim=1)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosAdaLayerNormZero(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_features: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
|
||||
self.activation = nn.SiLU()
|
||||
|
||||
if hidden_features is None:
|
||||
self.linear_1 = nn.Identity()
|
||||
else:
|
||||
self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
|
||||
|
||||
self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
embedded_timestep: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
embedded_timestep = self.activation(embedded_timestep)
|
||||
embedded_timestep = self.linear_1(embedded_timestep)
|
||||
embedded_timestep = self.linear_2(embedded_timestep)
|
||||
|
||||
if temb is not None:
|
||||
embedded_timestep = embedded_timestep + temb
|
||||
|
||||
shift, scale, gate = embedded_timestep.chunk(3, dim=1)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
return hidden_states, gate
|
||||
|
||||
|
||||
class CosmosAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# 1. QKV projections
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 2. QK normalization
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Apply RoPE
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
||||
|
||||
# 4. Prepare for GQA
|
||||
query_idx = torch.tensor(query.size(3), device=query.device)
|
||||
key_idx = torch.tensor(key.size(3), device=key.device)
|
||||
value_idx = torch.tensor(value.size(3), device=value.device)
|
||||
key = key.repeat_interleave(query_idx // key_idx, dim=3)
|
||||
value = value.repeat_interleave(query_idx // value_idx, dim=3)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
|
||||
|
||||
# 6. Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
cross_attention_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
adaln_lora_dim: int = 256,
|
||||
qk_norm: str = "rms_norm",
|
||||
out_bias: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
qk_norm=qk_norm,
|
||||
elementwise_affine=True,
|
||||
out_bias=out_bias,
|
||||
processor=CosmosAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.attn2 = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
qk_norm=qk_norm,
|
||||
elementwise_affine=True,
|
||||
out_bias=out_bias,
|
||||
processor=CosmosAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
embedded_timestep: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
extra_pos_emb: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if extra_pos_emb is not None:
|
||||
hidden_states = hidden_states + extra_pos_emb
|
||||
|
||||
# 1. Self Attention
|
||||
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
|
||||
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
|
||||
|
||||
# 2. Cross Attention
|
||||
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
|
||||
|
||||
# 3. Feed Forward
|
||||
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CosmosRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
max_size: Tuple[int, int, int] = (128, 240, 240),
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
||||
base_fps: int = 24,
|
||||
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
|
||||
self.patch_size = patch_size
|
||||
self.base_fps = base_fps
|
||||
|
||||
self.dim_h = hidden_size // 6 * 2
|
||||
self.dim_w = hidden_size // 6 * 2
|
||||
self.dim_t = hidden_size - self.dim_h - self.dim_w
|
||||
|
||||
self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
|
||||
self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
|
||||
self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
|
||||
device = hidden_states.device
|
||||
|
||||
h_theta = 10000.0 * self.h_ntk_factor
|
||||
w_theta = 10000.0 * self.w_ntk_factor
|
||||
t_theta = 10000.0 * self.t_ntk_factor
|
||||
|
||||
seq = torch.arange(max(self.max_size), device=device, dtype=torch.float32)
|
||||
dim_h_range = (
|
||||
torch.arange(0, self.dim_h, 2, device=device, dtype=torch.float32)[: (self.dim_h // 2)] / self.dim_h
|
||||
)
|
||||
dim_w_range = (
|
||||
torch.arange(0, self.dim_w, 2, device=device, dtype=torch.float32)[: (self.dim_w // 2)] / self.dim_w
|
||||
)
|
||||
dim_t_range = (
|
||||
torch.arange(0, self.dim_t, 2, device=device, dtype=torch.float32)[: (self.dim_t // 2)] / self.dim_t
|
||||
)
|
||||
h_spatial_freqs = 1.0 / (h_theta**dim_h_range)
|
||||
w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
|
||||
temporal_freqs = 1.0 / (t_theta**dim_t_range)
|
||||
|
||||
emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].repeat(pe_size[0], 1, pe_size[2], 1)
|
||||
emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].repeat(pe_size[0], pe_size[1], 1, 1)
|
||||
|
||||
# Apply sequence scaling in temporal dimension
|
||||
if fps is None:
|
||||
# Images
|
||||
emb_t = torch.outer(seq[: pe_size[0]], temporal_freqs)
|
||||
else:
|
||||
# Videos
|
||||
emb_t = torch.outer(seq[: pe_size[0]] / fps * self.base_fps, temporal_freqs)
|
||||
|
||||
emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1)
|
||||
freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2).float()
|
||||
cos = torch.cos(freqs)
|
||||
sin = torch.sin(freqs)
|
||||
return cos, sin
|
||||
|
||||
|
||||
class CosmosLearnablePositionalEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
max_size: Tuple[int, int, int],
|
||||
patch_size: Tuple[int, int, int],
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
|
||||
self.patch_size = patch_size
|
||||
self.eps = eps
|
||||
|
||||
self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
|
||||
self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
|
||||
self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
|
||||
|
||||
emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
|
||||
emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
|
||||
emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
|
||||
emb = emb_t + emb_h + emb_w
|
||||
emb = emb.flatten(1, 3)
|
||||
|
||||
norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
|
||||
norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
|
||||
return (emb / norm).type_as(hidden_states)
|
||||
|
||||
|
||||
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `32`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each attention head.
|
||||
num_layers (`int`, defaults to `28`):
|
||||
The number of layers of transformer blocks to use.
|
||||
mlp_ratio (`float`, defaults to `4.0`):
|
||||
The ratio of the hidden layer size to the input size in the feedforward network.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
adaln_lora_dim (`int`, defaults to `256`):
|
||||
The hidden dimension of the Adaptive LayerNorm LoRA layer.
|
||||
max_size (`Tuple[int, int, int]`, defaults to `(128, 240, 240)`):
|
||||
The maximum size of the input latent tensors in the temporal, height, and width dimensions.
|
||||
patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
|
||||
The patch size to use for patchifying the input latent tensors in the temporal, height, and width
|
||||
dimensions.
|
||||
rope_scale (`Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)`):
|
||||
The scaling factor to use for RoPE in the temporal, height, and width dimensions.
|
||||
concat_padding_mask (`bool`, defaults to `True`):
|
||||
Whether to concatenate the padding mask to the input latent tensors.
|
||||
extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
|
||||
The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
|
||||
_no_split_modules = ["CosmosTransformerBlock"]
|
||||
_keep_in_fp32_modules = ["learnable_pos_embed"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 28,
|
||||
mlp_ratio: float = 4.0,
|
||||
text_embed_dim: int = 1024,
|
||||
adaln_lora_dim: int = 256,
|
||||
max_size: Tuple[int, int, int] = (128, 240, 240),
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
||||
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
|
||||
concat_padding_mask: bool = True,
|
||||
extra_pos_embed_type: Optional[str] = "learnable",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch Embedding
|
||||
patch_embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels
|
||||
self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, patch_size, bias=False)
|
||||
|
||||
# 2. Positional Embedding
|
||||
self.rope = CosmosRotaryPosEmbed(
|
||||
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
|
||||
)
|
||||
|
||||
self.learnable_pos_embed = None
|
||||
if extra_pos_embed_type == "learnable":
|
||||
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
|
||||
hidden_size=hidden_size,
|
||||
max_size=max_size,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
# 3. Time Embedding
|
||||
self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
|
||||
|
||||
# 4. Transformer Blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CosmosTransformerBlock(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
cross_attention_dim=text_embed_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
adaln_lora_dim=adaln_lora_dim,
|
||||
qk_norm="rms_norm",
|
||||
out_bias=False,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output norm & projection
|
||||
self.norm_out = CosmosAdaLayerNorm(hidden_size, adaln_lora_dim)
|
||||
self.proj_out = nn.Linear(
|
||||
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
fps: Optional[int] = None,
|
||||
condition_mask: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
# 1. Concatenate padding mask if needed & prepare attention mask
|
||||
if condition_mask is not None:
|
||||
hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
|
||||
|
||||
if self.config.concat_padding_mask:
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
hidden_states = torch.cat(
|
||||
[hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
|
||||
|
||||
# 2. Generate positional embeddings
|
||||
image_rotary_emb = self.rope(hidden_states, fps=fps)
|
||||
extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
|
||||
|
||||
# 3. Patchify input
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
|
||||
|
||||
# 4. Timestep embeddings
|
||||
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
|
||||
|
||||
# 5. Transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
embedded_timestep,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
extra_pos_emb,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
embedded_timestep=embedded_timestep,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
extra_pos_emb=extra_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
# 6. Output norm & projection & unpatchify
|
||||
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
|
||||
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
|
||||
# Please just kill me at this point. What even is this permutation order and why is it different from the patching order?
|
||||
# Another few hours of sanity lost to the void.
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
@@ -241,7 +241,7 @@ class FluxTransformer2DModel(
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
@@ -447,8 +447,6 @@ class FluxTransformer2DModel(
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, pooled_projections)
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
@@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
|
||||
|
||||
|
||||
@@ -0,0 +1,416 @@
|
||||
# Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
from .transformer_hunyuan_video import (
|
||||
HunyuanVideoConditionEmbedding,
|
||||
HunyuanVideoPatchEmbed,
|
||||
HunyuanVideoSingleTransformerBlock,
|
||||
HunyuanVideoTokenRefiner,
|
||||
HunyuanVideoTransformerBlock,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.rope_dim = rope_dim
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
grid = torch.meshgrid(
|
||||
frame_indices.to(device=device, dtype=torch.float32),
|
||||
torch.arange(0, height, device=device, dtype=torch.float32),
|
||||
torch.arange(0, width, device=device, dtype=torch.float32),
|
||||
indexing="ij",
|
||||
) # 3 * [W, H, T]
|
||||
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
||||
|
||||
freqs = []
|
||||
for i in range(3):
|
||||
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
|
||||
freqs.append(freq)
|
||||
|
||||
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class FramepackClipVisionProjection(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.up = nn.Linear(in_channels, out_channels * 3)
|
||||
self.down = nn.Linear(out_channels * 3, out_channels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.up(hidden_states)
|
||||
hidden_states = F.silu(hidden_states)
|
||||
hidden_states = self.down(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoHistoryPatchEmbed(nn.Module):
|
||||
def __init__(self, in_channels: int, inner_dim: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
||||
self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
|
||||
self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents_clean: Optional[torch.Tensor] = None,
|
||||
latents_clean_2x: Optional[torch.Tensor] = None,
|
||||
latents_clean_4x: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if latents_clean is not None:
|
||||
latents_clean = self.proj(latents_clean)
|
||||
latents_clean = latents_clean.flatten(2).transpose(1, 2)
|
||||
if latents_clean_2x is not None:
|
||||
latents_clean_2x = _pad_for_3d_conv(latents_clean_2x, (2, 4, 4))
|
||||
latents_clean_2x = self.proj_2x(latents_clean_2x)
|
||||
latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2)
|
||||
if latents_clean_4x is not None:
|
||||
latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8))
|
||||
latents_clean_4x = self.proj_4x(latents_clean_4x)
|
||||
latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2)
|
||||
return latents_clean, latents_clean_2x, latents_clean_4x
|
||||
|
||||
|
||||
class HunyuanVideoFramepackTransformer3DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
|
||||
):
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
||||
_no_split_modules = [
|
||||
"HunyuanVideoTransformerBlock",
|
||||
"HunyuanVideoSingleTransformerBlock",
|
||||
"HunyuanVideoHistoryPatchEmbed",
|
||||
"HunyuanVideoTokenRefiner",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
num_attention_heads: int = 24,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 20,
|
||||
num_single_layers: int = 40,
|
||||
num_refiner_layers: int = 2,
|
||||
mlp_ratio: float = 4.0,
|
||||
patch_size: int = 2,
|
||||
patch_size_t: int = 1,
|
||||
qk_norm: str = "rms_norm",
|
||||
guidance_embeds: bool = True,
|
||||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
has_image_proj: int = False,
|
||||
image_proj_dim: int = 1152,
|
||||
has_clean_x_embedder: int = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Latent and condition embedders
|
||||
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
|
||||
|
||||
# Framepack history projection embedder
|
||||
self.clean_x_embedder = None
|
||||
if has_clean_x_embedder:
|
||||
self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
|
||||
|
||||
self.context_embedder = HunyuanVideoTokenRefiner(
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
|
||||
# Framepack image-conditioning embedder
|
||||
self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
|
||||
|
||||
self.time_text_embed = HunyuanVideoConditionEmbedding(
|
||||
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
|
||||
)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
||||
|
||||
# 3. Dual stream transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Single stream transformer blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
pooled_projections: torch.Tensor,
|
||||
image_embeds: torch.Tensor,
|
||||
indices_latents: torch.Tensor,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
latents_clean: Optional[torch.Tensor] = None,
|
||||
indices_latents_clean: Optional[torch.Tensor] = None,
|
||||
latents_history_2x: Optional[torch.Tensor] = None,
|
||||
indices_latents_history_2x: Optional[torch.Tensor] = None,
|
||||
latents_history_4x: Optional[torch.Tensor] = None,
|
||||
indices_latents_history_4x: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p, p_t = self.config.patch_size, self.config.patch_size_t
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
|
||||
|
||||
if indices_latents is None:
|
||||
indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
image_rotary_emb = self.rope(
|
||||
frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
|
||||
)
|
||||
|
||||
latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
|
||||
latents_clean, latents_history_2x, latents_history_4x
|
||||
)
|
||||
|
||||
if latents_clean is not None and indices_latents_clean is not None:
|
||||
image_rotary_emb_clean = self.rope(
|
||||
frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
|
||||
)
|
||||
if latents_history_2x is not None and indices_latents_history_2x is not None:
|
||||
image_rotary_emb_history_2x = self.rope(
|
||||
frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
|
||||
)
|
||||
if latents_history_4x is not None and indices_latents_history_4x is not None:
|
||||
image_rotary_emb_history_4x = self.rope(
|
||||
frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
|
||||
)
|
||||
|
||||
hidden_states, image_rotary_emb = self._pack_history_states(
|
||||
hidden_states,
|
||||
latents_clean,
|
||||
latents_history_2x,
|
||||
latents_history_4x,
|
||||
image_rotary_emb,
|
||||
image_rotary_emb_clean,
|
||||
image_rotary_emb_history_2x,
|
||||
image_rotary_emb_history_4x,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
|
||||
temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
|
||||
encoder_hidden_states_image = self.image_projection(image_embeds)
|
||||
attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))
|
||||
|
||||
# must cat before (not after) encoder_hidden_states, due to attn masking
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
|
||||
encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)
|
||||
|
||||
latent_sequence_length = hidden_states.shape[1]
|
||||
condition_sequence_length = encoder_hidden_states.shape[1]
|
||||
sequence_length = latent_sequence_length + condition_sequence_length
|
||||
attention_mask = torch.zeros(
|
||||
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
|
||||
) # [B, N]
|
||||
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
|
||||
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
|
||||
|
||||
if batch_size == 1:
|
||||
encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
|
||||
attention_mask = None
|
||||
else:
|
||||
for i in range(batch_size):
|
||||
attention_mask[i, : effective_sequence_length[i]] = True
|
||||
# [B, 1, 1, N], for broadcasting across attention heads
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, -original_context_length:]
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
|
||||
def _pack_history_states(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
latents_clean: Optional[torch.Tensor] = None,
|
||||
latents_history_2x: Optional[torch.Tensor] = None,
|
||||
latents_history_4x: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
):
|
||||
image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification
|
||||
|
||||
if latents_clean is not None and image_rotary_emb_clean is not None:
|
||||
hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
|
||||
image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
|
||||
image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
|
||||
|
||||
if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
|
||||
hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
|
||||
image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
|
||||
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
|
||||
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
|
||||
|
||||
if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
|
||||
hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
|
||||
image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
|
||||
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
|
||||
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
|
||||
|
||||
return hidden_states, tuple(image_rotary_emb)
|
||||
|
||||
def _pad_rotary_emb(
|
||||
self,
|
||||
image_rotary_emb: Tuple[torch.Tensor],
|
||||
height: int,
|
||||
width: int,
|
||||
kernel_size: Tuple[int, int, int],
|
||||
):
|
||||
# freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
|
||||
freqs_cos, freqs_sin = image_rotary_emb
|
||||
freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
|
||||
freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
|
||||
freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
|
||||
freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
|
||||
freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
|
||||
freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
|
||||
freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
|
||||
freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def _pad_for_3d_conv(x, kernel_size):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
|
||||
b, c, t, h, w = x.shape
|
||||
pt, ph, pw = kernel_size
|
||||
pad_t = (pt - (t % pt)) % pt
|
||||
pad_h = (ph - (h % ph)) % ph
|
||||
pad_w = (pw - (w % pw)) % pw
|
||||
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
|
||||
|
||||
|
||||
def _center_down_sample_3d(x, kernel_size):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
|
||||
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
|
||||
@@ -156,6 +156,8 @@ else:
|
||||
]
|
||||
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
|
||||
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
|
||||
_import_structure["consisid"] = ["ConsisIDPipeline"]
|
||||
_import_structure["cosmos"] = ["CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline"]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
@@ -227,6 +229,7 @@ else:
|
||||
"HunyuanVideoPipeline",
|
||||
"HunyuanSkyreelsImageToVideoPipeline",
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
"HunyuanVideoFramepackPipeline",
|
||||
]
|
||||
_import_structure["kandinsky"] = [
|
||||
"KandinskyCombinedPipeline",
|
||||
@@ -265,7 +268,12 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
|
||||
_import_structure["ltx"] = [
|
||||
"LTXPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["marigold"].extend(
|
||||
@@ -278,6 +286,7 @@ else:
|
||||
_import_structure["mochi"] = ["MochiPipeline"]
|
||||
_import_structure["musicldm"] = ["MusicLDMPipeline"]
|
||||
_import_structure["omnigen"] = ["OmniGenPipeline"]
|
||||
_import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"]
|
||||
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
|
||||
_import_structure["pia"] = ["PIAPipeline"]
|
||||
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
|
||||
@@ -545,6 +554,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
)
|
||||
from .cosmos import CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
@@ -589,6 +599,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_video import (
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoFramepackPipeline,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
HunyuanVideoPipeline,
|
||||
)
|
||||
@@ -631,7 +642,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
@@ -722,6 +733,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
|
||||
@@ -40,6 +40,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.import_utils import is_transformers_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
@@ -312,8 +313,19 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.0.dev0"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
|
||||
|
||||
@@ -322,9 +322,8 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
|
||||
saved using
|
||||
[`~DiffusionPipeline.save_pretrained`].
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
torch_dtype (`torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@@ -619,8 +618,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
saved using
|
||||
[`~DiffusionPipeline.save_pretrained`].
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@@ -930,8 +928,7 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
saved using
|
||||
[`~DiffusionPipeline.save_pretrained`].
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
|
||||
_import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
|
||||
from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,667 @@
|
||||
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
|
||||
from ...schedulers import EDMEulerScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CosmosTextToWorldPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
|
||||
>>> pipe = CosmosTextToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
|
||||
|
||||
>>> output = pipe(prompt=prompt).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=30)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CosmosTextToWorldPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Cosmos uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CosmosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLCosmos`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLCosmos,
|
||||
scheduler: EDMEulerScheduler,
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
|
||||
)
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=True,
|
||||
return_offsets_mapping=False,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=prompt_attention_mask
|
||||
).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
lengths = prompt_attention_mask.sum(dim=1).cpu()
|
||||
for i, length in enumerate(lengths):
|
||||
prompt_embeds[i, length:] = 0
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: 16,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 121,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents * self.scheduler.config.sigma_max
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 121,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 7.0,
|
||||
fps: int = 30,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `720`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`.
|
||||
fps (`int`, defaults to `30`):
|
||||
The frames per second of the generated video.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if self.safety_checker is None:
|
||||
raise ValueError(
|
||||
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||
f"Please ensure that you are compliant with the license agreement."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
if prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
for p in prompt_list:
|
||||
if not self.safety_checker.check_text_safety(p):
|
||||
raise ValueError(
|
||||
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||
)
|
||||
self.safety_checker.to("cpu")
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
|
||||
|
||||
latent_model_input = latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = latent_model_input.to(transformer_dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
sample = latents
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
fps=fps,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
|
||||
sample = torch.cat([sample, sample])
|
||||
|
||||
# pred_original_sample (x0)
|
||||
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
|
||||
self.scheduler._step_index -= 1
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
# pred_sample (eps)
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
|
||||
)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
if self.vae.config.latents_mean is not None:
|
||||
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
|
||||
latents_mean = (
|
||||
torch.tensor(latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
|
||||
.to(latents)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(latents_std)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
|
||||
.to(latents)
|
||||
)
|
||||
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
|
||||
else:
|
||||
latents = latents / self.scheduler.config.sigma_data
|
||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
video = (video * 255).astype(np.uint8)
|
||||
video_batch = []
|
||||
for vid in video:
|
||||
vid = self.safety_checker.check_video_safety(vid)
|
||||
video_batch.append(vid)
|
||||
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
self.safety_checker.to("cpu")
|
||||
else:
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,828 @@
|
||||
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
|
||||
from ...schedulers import EDMEulerScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
class CosmosSafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
|
||||
)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
Image conditioning:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CosmosVideoToWorldPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
|
||||
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day."
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
|
||||
... )
|
||||
|
||||
>>> video = pipe(image=image, prompt=prompt).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
Video conditioning:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CosmosVideoToWorldPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
|
||||
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.transformer = torch.compile(pipe.transformer)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
|
||||
>>> video = load_video(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
|
||||
... )[
|
||||
... :21
|
||||
... ] # This example uses only the first 21 frames
|
||||
|
||||
>>> video = pipe(video=video, prompt=prompt).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=30)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class CosmosVideoToWorldPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for image-to-video and video-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Cosmos uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CosmosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLCosmos`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLCosmos,
|
||||
scheduler: EDMEulerScheduler,
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
safety_checker = CosmosSafetyChecker()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
|
||||
)
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_length=True,
|
||||
return_offsets_mapping=False,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=prompt_attention_mask
|
||||
).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
lengths = prompt_attention_mask.sum(dim=1).cpu()
|
||||
for i, length in enumerate(lengths):
|
||||
prompt_embeds[i, length:] = 0
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_channels_latents: 16,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 121,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
input_frames_guidance: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
num_cond_frames = video.size(2)
|
||||
if num_cond_frames >= num_frames:
|
||||
# Take the last `num_frames` frames for conditioning
|
||||
num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
video = video[:, :, -num_frames:]
|
||||
else:
|
||||
num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
num_padding_frames = num_frames - num_cond_frames
|
||||
padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4))
|
||||
video = torch.cat([video, padding], dim=2)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
|
||||
if self.vae.config.latents_mean is not None:
|
||||
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
|
||||
latents_mean = (
|
||||
torch.tensor(latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
|
||||
.to(init_latents)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(latents_std)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
|
||||
.to(init_latents)
|
||||
)
|
||||
init_latents = (init_latents - latents_mean) * self.scheduler.config.sigma_data / latents_std
|
||||
else:
|
||||
init_latents = init_latents * self.scheduler.config.sigma_data
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latents = latents * self.scheduler.config.sigma_max
|
||||
|
||||
padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
|
||||
ones_padding = latents.new_ones(padding_shape)
|
||||
zeros_padding = latents.new_zeros(padding_shape)
|
||||
|
||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
cond_indicator[:, :, :num_cond_latent_frames] = 1.0
|
||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||
|
||||
uncond_indicator = uncond_mask = None
|
||||
if do_classifier_free_guidance:
|
||||
uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
|
||||
uncond_mask = zeros_padding
|
||||
if not input_frames_guidance:
|
||||
uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
|
||||
|
||||
return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
image=None,
|
||||
video=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if image is None and video is None:
|
||||
raise ValueError("Either `image` or `video` has to be provided.")
|
||||
if image is not None and video is not None:
|
||||
raise ValueError("Only one of `image` or `video` has to be provided.")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput = None,
|
||||
video: List[PipelineImageInput] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 704,
|
||||
width: int = 1280,
|
||||
num_frames: int = 121,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 7.0,
|
||||
input_frames_guidance: bool = False,
|
||||
augment_sigma: float = 0.001,
|
||||
fps: int = 30,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `720`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `1280`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `129`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `6.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`.
|
||||
fps (`int`, defaults to `30`):
|
||||
The frames per second of the generated video.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~CosmosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if self.safety_checker is None:
|
||||
raise ValueError(
|
||||
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
|
||||
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
|
||||
f"Please ensure that you are compliant with the license agreement."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs, image, video)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
if prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
for p in prompt_list:
|
||||
if not self.safety_checker.check_text_safety(p):
|
||||
raise ValueError(
|
||||
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
|
||||
f"prompt abides by the NVIDIA Open Model License Agreement."
|
||||
)
|
||||
self.safety_checker.to("cpu")
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
if image is not None:
|
||||
video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
|
||||
else:
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
video = video.to(device=device, dtype=vae_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels - 1
|
||||
latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
|
||||
video,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
self.do_classifier_free_guidance,
|
||||
input_frames_guidance,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
if self.do_classifier_free_guidance:
|
||||
uncond_mask = uncond_mask.to(transformer_dtype)
|
||||
|
||||
augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32)
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
|
||||
|
||||
current_sigma = self.scheduler.sigmas[i]
|
||||
is_augment_sigma_greater = augment_sigma >= current_sigma
|
||||
|
||||
c_in_augment = self.scheduler._get_conditioning_c_in(augment_sigma)
|
||||
c_in_original = self.scheduler._get_conditioning_c_in(current_sigma)
|
||||
|
||||
current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
|
||||
cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
|
||||
cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
|
||||
cond_latent = cond_latent * c_in_augment / c_in_original
|
||||
cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
|
||||
cond_latent = self.scheduler.scale_model_input(cond_latent, t)
|
||||
cond_latent = cond_latent.to(transformer_dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=cond_latent,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
fps=fps,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
sample = latents
|
||||
if self.do_classifier_free_guidance:
|
||||
current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
|
||||
uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
|
||||
uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
|
||||
uncond_latent = uncond_latent * c_in_augment / c_in_original
|
||||
uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
|
||||
uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
|
||||
uncond_latent = uncond_latent.to(transformer_dtype)
|
||||
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=uncond_latent,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
fps=fps,
|
||||
condition_mask=uncond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
|
||||
sample = torch.cat([sample, sample])
|
||||
|
||||
# pred_original_sample (x0)
|
||||
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
|
||||
self.scheduler._step_index -= 1
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
|
||||
noise_pred_uncond = (
|
||||
current_uncond_indicator * conditioning_latents
|
||||
+ (1 - current_uncond_indicator) * noise_pred_uncond
|
||||
)
|
||||
noise_pred_cond = (
|
||||
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
|
||||
)
|
||||
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = (
|
||||
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred
|
||||
)
|
||||
|
||||
# pred_sample (eps)
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
|
||||
)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
if self.vae.config.latents_mean is not None:
|
||||
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
|
||||
latents_mean = (
|
||||
torch.tensor(latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
|
||||
.to(latents)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(latents_std)
|
||||
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
|
||||
.to(latents)
|
||||
)
|
||||
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
|
||||
else:
|
||||
latents = latents / self.scheduler.config.sigma_data
|
||||
video = self.vae.decode(latents.to(vae_dtype), return_dict=False)[0]
|
||||
|
||||
if self.safety_checker is not None:
|
||||
self.safety_checker.to(device)
|
||||
video = self.video_processor.postprocess_video(video, output_type="np")
|
||||
video = (video * 255).astype(np.uint8)
|
||||
video_batch = []
|
||||
for vid in video:
|
||||
vid = self.safety_checker.check_video_safety(vid)
|
||||
video_batch.append(vid)
|
||||
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
self.safety_checker.to("cpu")
|
||||
else:
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class CosmosPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Cosmos pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -687,11 +687,11 @@ class FluxPipeline(
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -700,7 +700,7 @@ class FluxPipeline(
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
|
||||
@@ -607,6 +607,39 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
|
||||
@@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
|
||||
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
>>> from diffusers import HiDreamImagePipeline
|
||||
|
||||
|
||||
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
|
||||
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
... output_hidden_states=True,
|
||||
|
||||
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
|
||||
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
|
||||
_import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"]
|
||||
_import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -36,6 +37,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
|
||||
from .pipeline_hunyuan_video import HunyuanVideoPipeline
|
||||
from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline
|
||||
from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline
|
||||
|
||||
else:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
@@ -18,3 +21,19 @@ class HunyuanVideoPipelineOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunyuanVideoFramepackPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for HunyuanVideo pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`. Or, a list of torch tensors where each tensor
|
||||
corresponds to a latent that decodes to multiple frames.
|
||||
"""
|
||||
|
||||
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]], List[torch.Tensor]]
|
||||
|
||||
@@ -22,9 +22,11 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
|
||||
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
|
||||
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
|
||||
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
|
||||
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -34,9 +36,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
from .pipeline_ltx import LTXPipeline
|
||||
from .pipeline_ltx_condition import LTXConditionPipeline
|
||||
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
|
||||
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||
self.activation = torch.nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.activation(hidden_states + residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PixelShuffleND(torch.nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
if dims not in [1, 2, 3]:
|
||||
raise ValueError("dims must be 1, 2, or 3")
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:3]))
|
||||
.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
.flatten(6, 7)
|
||||
.flatten(4, 5)
|
||||
.flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 2:
|
||||
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 1:
|
||||
# temporal: b (c p1) f h w -> b c (f p1) h w
|
||||
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
|
||||
|
||||
|
||||
class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `128`):
|
||||
Number of channels in the input latent
|
||||
mid_channels (`int`, defaults to `512`):
|
||||
Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`, defaults to `4`):
|
||||
Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`, defaults to `3`):
|
||||
Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`, defaults to `True`):
|
||||
Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`, defaults to `False`):
|
||||
Whether to temporally upsample the latent
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 512,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
|
||||
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = torch.nn.SiLU()
|
||||
|
||||
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||
|
||||
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.dims == 2:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
if self.temporal_upsample:
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states[:, :, 1:, :, :]
|
||||
else:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@@ -789,6 +789,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
|
||||
@@ -430,6 +430,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
video,
|
||||
frame_index,
|
||||
strength,
|
||||
denoise_strength,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
@@ -497,6 +498,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
|
||||
raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
|
||||
|
||||
if denoise_strength < 0 or denoise_strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_video_ids(
|
||||
batch_size: int,
|
||||
@@ -649,6 +653,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
width: int = 704,
|
||||
num_frames: int = 161,
|
||||
num_prefix_latent_frames: int = 2,
|
||||
sigma: Optional[torch.Tensor] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
@@ -658,7 +664,18 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if latents is not None and sigma is not None:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(
|
||||
f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input."
|
||||
)
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
sigma = sigma.to(device=device, dtype=dtype)
|
||||
latents = sigma * noise + (1 - sigma) * latents
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
if len(conditions) > 0:
|
||||
condition_latent_frames_mask = torch.zeros(
|
||||
@@ -766,6 +783,13 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
|
||||
return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
|
||||
|
||||
def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength):
|
||||
num_steps = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
start_index = max(num_inference_steps - num_steps, 0)
|
||||
sigmas = sigmas[start_index:]
|
||||
timesteps = timesteps[start_index:]
|
||||
return sigmas, timesteps, num_inference_steps - start_index
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -799,6 +823,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
video: List[PipelineImageInput] = None,
|
||||
frame_index: Union[int, List[int]] = 0,
|
||||
strength: Union[float, List[float]] = 1.0,
|
||||
denoise_strength: float = 1.0,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 512,
|
||||
@@ -842,6 +867,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
generation. If not provided, one has to pass `conditions`.
|
||||
strength (`float` or `List[float]`, *optional*):
|
||||
The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
|
||||
denoise_strength (`float`, defaults to `1.0`):
|
||||
The strength of the noise added to the latents for editing. Higher strength leads to more noise added
|
||||
to the latents, therefore leading to more differences between original video and generated video. This
|
||||
is useful for video-to-video editing.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
@@ -918,8 +947,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
if latents is not None:
|
||||
raise ValueError("Passing latents is not yet supported.")
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
@@ -929,6 +956,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
video=video,
|
||||
frame_index=frame_index,
|
||||
strength=strength,
|
||||
denoise_strength=denoise_strength,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
@@ -977,8 +1005,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
strength = [strength] * num_conditions
|
||||
|
||||
device = self._execution_device
|
||||
vae_dtype = self.vae.dtype
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
# 3. Prepare text embeddings & conditioning image/video
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
@@ -1000,8 +1029,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
|
||||
conditioning_tensors = []
|
||||
is_conditioning_image_or_video = image is not None or video is not None
|
||||
if is_conditioning_image_or_video:
|
||||
@@ -1032,7 +1059,25 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
)
|
||||
conditioning_tensors.append(condition_tensor)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
# 4. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
sigmas = linear_quadratic_schedule(num_inference_steps)
|
||||
timesteps = sigmas * 1000
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
latent_sigma = None
|
||||
if denoise_strength < 1:
|
||||
sigmas, timesteps, num_inference_steps = self.get_timesteps(
|
||||
sigmas, timesteps, num_inference_steps, denoise_strength
|
||||
)
|
||||
latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
|
||||
conditioning_tensors,
|
||||
@@ -1043,6 +1088,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
sigma=latent_sigma,
|
||||
latents=latents,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
@@ -1056,21 +1103,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
if self.do_classifier_free_guidance:
|
||||
video_coords = torch.cat([video_coords, video_coords], dim=0)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
sigmas = linear_quadratic_schedule(num_inference_steps)
|
||||
timesteps = sigmas * 1000
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -1168,7 +1200,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLLTXVideo
|
||||
from ...utils import get_logger
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
from .pipeline_output import LTXPipelineOutput
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LTXLatentUpsamplePipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLLTXVideo,
|
||||
latent_upsampler: LTXLatentUpsamplerModel,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
return init_latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def check_inputs(self, video, height, width, latents):
|
||||
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if video is not None and latents is not None:
|
||||
raise ValueError("Only one of `video` or `latents` can be provided.")
|
||||
if video is None and latents is None:
|
||||
raise ValueError("One of `video` or `latents` has to be provided.")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
height: int = 512,
|
||||
width: int = 704,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
self.check_inputs(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
if video is not None:
|
||||
# Batched video input is not yet tested/supported. TODO: take a look later
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = latents.shape[0]
|
||||
device = self._execution_device
|
||||
|
||||
if video is not None:
|
||||
num_frames = len(video)
|
||||
if num_frames % self.vae_temporal_compression_ratio != 1:
|
||||
num_frames = (
|
||||
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
|
||||
)
|
||||
video = video[:num_frames]
|
||||
logger.warning(
|
||||
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
|
||||
)
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
||||
video = video.to(device=device, dtype=torch.float32)
|
||||
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(self.latent_upsampler.dtype)
|
||||
latents = self.latent_upsampler(latents)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
else:
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
@@ -248,9 +248,8 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
pretrained pipeline hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
using [`~FlaxDiffusionPipeline.save_pretrained`].
|
||||
dtype (`str` or `jnp.dtype`, *optional*):
|
||||
Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is
|
||||
automatically derived from the model's weights.
|
||||
dtype (`jnp.dtype`, *optional*):
|
||||
Override the default `jnp.dtype` and load the model under this dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
@@ -675,8 +675,10 @@ def load_sub_model(
|
||||
use_safetensors: bool,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]],
|
||||
provider_options: Any,
|
||||
quantization_config: Optional[Any] = None,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
from ..quantizers import PipelineQuantizationConfig
|
||||
|
||||
# retrieve class candidates
|
||||
|
||||
@@ -769,6 +771,17 @@ def load_sub_model(
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
if (
|
||||
quantization_config is not None
|
||||
and isinstance(quantization_config, PipelineQuantizationConfig)
|
||||
and issubclass(class_obj, torch.nn.Module)
|
||||
):
|
||||
model_quant_config = quantization_config._resolve_quant_config(
|
||||
is_diffusers=is_diffusers_model, module_name=name
|
||||
)
|
||||
if model_quant_config is not None:
|
||||
loading_kwargs["quantization_config"] = model_quant_config
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if dduf_entries:
|
||||
loading_kwargs["dduf_entries"] = dduf_entries
|
||||
|
||||
@@ -47,6 +47,7 @@ from ..configuration_utils import ConfigMixin
|
||||
from ..models import AutoencoderKL
|
||||
from ..models.attention_processor import FusedAttnProcessor2_0
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
|
||||
from ..quantizers import PipelineQuantizationConfig
|
||||
from ..quantizers.bitsandbytes.utils import _check_bnb_status
|
||||
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from ..utils import (
|
||||
@@ -572,12 +573,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
saved using
|
||||
[`~DiffusionPipeline.save_pretrained`].
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
|
||||
torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
|
||||
dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
|
||||
`dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
|
||||
unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
|
||||
torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
|
||||
torch_dtype (`torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. To load submodels with
|
||||
different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`).
|
||||
Set the default dtype for unspecified components with `default` (for example `{'transformer':
|
||||
torch.bfloat16, 'default': torch.float16}`). If a component is not specified and no default is set,
|
||||
`torch.float32` is used.
|
||||
custom_pipeline (`str`, *optional*):
|
||||
|
||||
<Tip warning={true}>
|
||||
@@ -725,6 +726,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
@@ -741,6 +743,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
|
||||
raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
@@ -1001,6 +1006,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_safetensors=use_safetensors,
|
||||
dduf_entries=dduf_entries,
|
||||
provider_options=provider_options,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
|
||||
@@ -30,18 +30,11 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["clip_image_project_model"] = ["CLIPImageProjection"]
|
||||
_import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]
|
||||
_import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
|
||||
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
|
||||
_import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_visualcloze_combined"] = ["VisualClozePipeline"]
|
||||
_import_structure["pipeline_visualcloze_generation"] = ["VisualClozeGenerationPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_visualcloze_combined import VisualClozePipeline
|
||||
from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,444 @@
|
||||
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..flux.pipeline_flux_fill import FluxFillPipeline as VisualClozeUpsamplingPipeline
|
||||
from ..flux.pipeline_output import FluxPipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import VisualClozePipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> image_paths = [
|
||||
... # in-context examples
|
||||
... [
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
|
||||
... ),
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
|
||||
... ),
|
||||
... ],
|
||||
... # query with the target image
|
||||
... [
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
|
||||
... ),
|
||||
... None, # No image needed for the target image
|
||||
... ],
|
||||
... ]
|
||||
>>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
|
||||
>>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
|
||||
>>> pipe = VisualClozePipeline.from_pretrained(
|
||||
... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = pipe(
|
||||
... task_prompt=task_prompt,
|
||||
... content_prompt=content_prompt,
|
||||
... image=image_paths,
|
||||
... upsampling_width=1344,
|
||||
... upsampling_height=768,
|
||||
... upsampling_strength=0.4,
|
||||
... guidance_scale=30,
|
||||
... num_inference_steps=30,
|
||||
... max_sequence_length=512,
|
||||
... generator=torch.Generator("cpu").manual_seed(0),
|
||||
... ).images[0][0]
|
||||
>>> image.save("visualcloze.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class VisualClozePipeline(
|
||||
DiffusionPipeline,
|
||||
FluxLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
The VisualCloze pipeline for image generation with visual context. Reference:
|
||||
https://github.com/lzyhha/VisualCloze/tree/main. This pipeline is designed to generate images based on visual
|
||||
in-context examples.
|
||||
|
||||
Args:
|
||||
transformer ([`FluxTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
text_encoder_2 ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
resolution (`int`, *optional*, defaults to 384):
|
||||
The resolution of each image when concatenating images from the query and in-context examples.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: T5TokenizerFast,
|
||||
transformer: FluxTransformer2DModel,
|
||||
resolution: int = 384,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.generation_pipe = VisualClozeGenerationPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
resolution=resolution,
|
||||
)
|
||||
self.upsampling_pipe = VisualClozeUpsamplingPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
image,
|
||||
task_prompt,
|
||||
content_prompt,
|
||||
upsampling_height,
|
||||
upsampling_width,
|
||||
strength,
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if upsampling_height is not None and upsampling_height % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`upsampling_height`has to be divisible by {self.vae_scale_factor * 2} but are {upsampling_height}. Dimensions will be resized accordingly"
|
||||
)
|
||||
if upsampling_width is not None and upsampling_width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`upsampling_width` have to be divisible by {self.vae_scale_factor * 2} but are {upsampling_width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
# Validate prompt inputs
|
||||
if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
|
||||
raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
|
||||
|
||||
if task_prompt is None and content_prompt is None and prompt_embeds is None:
|
||||
raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")
|
||||
|
||||
# Validate prompt types and consistency
|
||||
if task_prompt is None:
|
||||
raise ValueError("`task_prompt` is missing.")
|
||||
|
||||
if task_prompt is not None and not isinstance(task_prompt, (str, list)):
|
||||
raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}")
|
||||
|
||||
if content_prompt is not None and not isinstance(content_prompt, (str, list)):
|
||||
raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}")
|
||||
|
||||
if isinstance(task_prompt, list) or isinstance(content_prompt, list):
|
||||
if not isinstance(task_prompt, list) or not isinstance(content_prompt, list):
|
||||
raise ValueError(
|
||||
f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, "
|
||||
f"got {type(task_prompt)} and {type(content_prompt)}"
|
||||
)
|
||||
if len(content_prompt) != len(task_prompt):
|
||||
raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.")
|
||||
|
||||
for sample in image:
|
||||
if not isinstance(sample, list) or not isinstance(sample[0], list):
|
||||
raise ValueError("Each sample in the batch must have a 2D list of images.")
|
||||
if len({len(row) for row in sample}) != 1:
|
||||
raise ValueError("Each in-context example and query should contain the same number of images.")
|
||||
if not any(img is None for img in sample[-1]):
|
||||
raise ValueError("There are no targets in the query, which should be represented as None.")
|
||||
for row in sample[:-1]:
|
||||
if any(img is None for img in row):
|
||||
raise ValueError("Images are missing in in-context examples.")
|
||||
|
||||
# Validate embeddings
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
# Validate sequence length
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
task_prompt: Union[str, List[str]] = None,
|
||||
content_prompt: Union[str, List[str]] = None,
|
||||
image: Optional[torch.FloatTensor] = None,
|
||||
upsampling_height: Optional[int] = None,
|
||||
upsampling_width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 30.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
upsampling_strength: float = 1.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the VisualCloze pipeline for generation.
|
||||
|
||||
Args:
|
||||
task_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to define the task intention.
|
||||
content_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to define the content or caption of the target image to be generated.
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
|
||||
upsampling_height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By
|
||||
default, the image is upsampled by a factor of three, and the base resolution is determined by the
|
||||
resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is
|
||||
specified, the other will be automatically set based on the aspect ratio.
|
||||
upsampling_width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By
|
||||
default, the image is upsampled by a factor of three, and the base resolution is determined by the
|
||||
resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is
|
||||
specified, the other will be automatically set based on the aspect ratio.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 30.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
upsampling_strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image` when upsampling the results. Must be between 0 and
|
||||
1. The generated image is used as a starting point and more noise is added the higher the
|
||||
`upsampling_strength`. The number of denoising steps depends on the amount of noise initially added.
|
||||
When `upsampling_strength` is 1, added noise is maximum and the denoising process runs for the full
|
||||
number of iterations specified in `num_inference_steps`. A value of 0 skips the upsampling step and
|
||||
output the results at the resolution of `self.resolution`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
||||
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
||||
images.
|
||||
"""
|
||||
|
||||
generation_output = self.generation_pipe(
|
||||
task_prompt=task_prompt,
|
||||
content_prompt=content_prompt,
|
||||
image=image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
sigmas=sigmas,
|
||||
guidance_scale=guidance_scale,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
output_type=output_type if upsampling_strength == 0 else "pil",
|
||||
)
|
||||
if upsampling_strength == 0:
|
||||
if not return_dict:
|
||||
return (generation_output,)
|
||||
|
||||
return FluxPipelineOutput(images=generation_output)
|
||||
|
||||
# Upsampling the generated images
|
||||
# 1. Prepare the input images and prompts
|
||||
if not isinstance(content_prompt, (list)):
|
||||
content_prompt = [content_prompt]
|
||||
n_target_per_sample = []
|
||||
upsampling_image = []
|
||||
upsampling_mask = []
|
||||
upsampling_prompt = []
|
||||
upsampling_generator = generator if isinstance(generator, (torch.Generator,)) else []
|
||||
for i in range(len(generation_output.images)):
|
||||
n_target_per_sample.append(len(generation_output.images[i]))
|
||||
for image in generation_output.images[i]:
|
||||
upsampling_image.append(image)
|
||||
upsampling_mask.append(Image.new("RGB", image.size, (255, 255, 255)))
|
||||
upsampling_prompt.append(
|
||||
content_prompt[i % len(content_prompt)] if content_prompt[i % len(content_prompt)] else ""
|
||||
)
|
||||
if not isinstance(generator, (torch.Generator,)):
|
||||
upsampling_generator.append(generator[i % len(content_prompt)])
|
||||
|
||||
# 2. Apply the denosing loop
|
||||
upsampling_output = self.upsampling_pipe(
|
||||
prompt=upsampling_prompt,
|
||||
image=upsampling_image,
|
||||
mask_image=upsampling_mask,
|
||||
height=upsampling_height,
|
||||
width=upsampling_width,
|
||||
strength=upsampling_strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
sigmas=sigmas,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=upsampling_generator,
|
||||
output_type=output_type,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
image = upsampling_output.images
|
||||
|
||||
output = []
|
||||
if output_type == "pil":
|
||||
# Each sample in the batch may have multiple output images. When returning as PIL images,
|
||||
# these images cannot be concatenated. Therefore, for each sample,
|
||||
# a list is used to represent all the output images.
|
||||
output = []
|
||||
start = 0
|
||||
for n in n_target_per_sample:
|
||||
output.append(image[start : start + n])
|
||||
start += n
|
||||
else:
|
||||
output = image
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return FluxPipelineOutput(images=output)
|
||||
@@ -0,0 +1,952 @@
|
||||
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..flux.pipeline_flux_fill import calculate_shift, retrieve_latents, retrieve_timesteps
|
||||
from ..flux.pipeline_output import FluxPipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .visualcloze_utils import VisualClozeProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> image_paths = [
|
||||
... # in-context examples
|
||||
... [
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
|
||||
... ),
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
|
||||
... ),
|
||||
... ],
|
||||
... # query with the target image
|
||||
... [
|
||||
... load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
|
||||
... ),
|
||||
... None, # No image needed for the target image
|
||||
... ],
|
||||
... ]
|
||||
>>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
|
||||
>>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
|
||||
>>> pipe = VisualClozeGenerationPipeline.from_pretrained(
|
||||
... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> image = pipe(
|
||||
... task_prompt=task_prompt,
|
||||
... content_prompt=content_prompt,
|
||||
... image=image_paths,
|
||||
... guidance_scale=30,
|
||||
... num_inference_steps=30,
|
||||
... max_sequence_length=512,
|
||||
... generator=torch.Generator("cpu").manual_seed(0),
|
||||
... ).images[0][0]
|
||||
|
||||
>>> # optional, upsampling the generated image
|
||||
>>> pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
|
||||
>>> pipe_upsample.to("cuda")
|
||||
|
||||
>>> mask_image = Image.new("RGB", image.size, (255, 255, 255))
|
||||
|
||||
>>> image = pipe_upsample(
|
||||
... image=image,
|
||||
... mask_image=mask_image,
|
||||
... prompt=content_prompt,
|
||||
... width=1344,
|
||||
... height=768,
|
||||
... strength=0.4,
|
||||
... guidance_scale=30,
|
||||
... num_inference_steps=30,
|
||||
... max_sequence_length=512,
|
||||
... generator=torch.Generator("cpu").manual_seed(0),
|
||||
... ).images[0]
|
||||
|
||||
>>> image.save("visualcloze.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class VisualClozeGenerationPipeline(
|
||||
DiffusionPipeline,
|
||||
FluxLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
The VisualCloze pipeline for image generation with visual context. Reference:
|
||||
https://github.com/lzyhha/VisualCloze/tree/main This pipeline is designed to generate images based on visual
|
||||
in-context examples.
|
||||
|
||||
Args:
|
||||
transformer ([`FluxTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
text_encoder_2 ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
resolution (`int`, *optional*, defaults to 384):
|
||||
The resolution of each image when concatenating images from the query and in-context examples.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: T5TokenizerFast,
|
||||
transformer: FluxTransformer2DModel,
|
||||
resolution: int = 384,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.resolution = resolution
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VisualClozeProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels, resolution=resolution
|
||||
)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
||||
|
||||
dtype = self.text_encoder_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_length=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Modified from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
layout_prompt: Union[str, List[str]],
|
||||
task_prompt: Union[str, List[str]],
|
||||
content_prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
layout_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to define the number of in-context examples and the number of images involved in
|
||||
the task.
|
||||
task_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to define the task intention.
|
||||
content_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to define the content or caption of the target image to be generated.
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
if isinstance(layout_prompt, str):
|
||||
layout_prompt = [layout_prompt]
|
||||
task_prompt = [task_prompt]
|
||||
content_prompt = [content_prompt]
|
||||
|
||||
def _preprocess(prompt, content=False):
|
||||
if prompt is not None:
|
||||
return f"The last image of the last row depicts: {prompt}" if content else prompt
|
||||
else:
|
||||
return ""
|
||||
|
||||
prompt = [
|
||||
f"{_preprocess(layout_prompt[i])} {_preprocess(task_prompt[i])} {_preprocess(content_prompt[i], content=True)}".strip()
|
||||
for i in range(len(layout_prompt))
|
||||
]
|
||||
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
image,
|
||||
task_prompt,
|
||||
content_prompt,
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
# Validate prompt inputs
|
||||
if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
|
||||
raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
|
||||
|
||||
if task_prompt is None and content_prompt is None and prompt_embeds is None:
|
||||
raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")
|
||||
|
||||
# Validate prompt types and consistency
|
||||
if task_prompt is None:
|
||||
raise ValueError("`task_prompt` is missing.")
|
||||
|
||||
if task_prompt is not None and not isinstance(task_prompt, (str, list)):
|
||||
raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}")
|
||||
|
||||
if content_prompt is not None and not isinstance(content_prompt, (str, list)):
|
||||
raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}")
|
||||
|
||||
if isinstance(task_prompt, list) or isinstance(content_prompt, list):
|
||||
if not isinstance(task_prompt, list) or not isinstance(content_prompt, list):
|
||||
raise ValueError(
|
||||
f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, "
|
||||
f"got {type(task_prompt)} and {type(content_prompt)}"
|
||||
)
|
||||
if len(content_prompt) != len(task_prompt):
|
||||
raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.")
|
||||
|
||||
for sample in image:
|
||||
if not isinstance(sample, list) or not isinstance(sample[0], list):
|
||||
raise ValueError("Each sample in the batch must have a 2D list of images.")
|
||||
if len({len(row) for row in sample}) != 1:
|
||||
raise ValueError("Each in-context example and query should contain the same number of images.")
|
||||
if not any(img is None for img in sample[-1]):
|
||||
raise ValueError("There are no targets in the query, which should be represented as None.")
|
||||
for row in sample[:-1]:
|
||||
if any(img is None for img in row):
|
||||
raise ValueError("Images are missing in in-context examples.")
|
||||
|
||||
# Validate embeddings
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
# Validate sequence length
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(image, vae_scale_factor, device, dtype):
|
||||
latent_image_ids = []
|
||||
|
||||
for idx, img in enumerate(image, start=1):
|
||||
img = img.squeeze(0)
|
||||
channels, height, width = img.shape
|
||||
|
||||
num_patches_h = height // vae_scale_factor // 2
|
||||
num_patches_w = width // vae_scale_factor // 2
|
||||
|
||||
patch_ids = torch.zeros(num_patches_h, num_patches_w, 3, device=device, dtype=dtype)
|
||||
patch_ids[..., 0] = idx
|
||||
patch_ids[..., 1] = torch.arange(num_patches_h, device=device, dtype=dtype)[:, None]
|
||||
patch_ids[..., 2] = torch.arange(num_patches_w, device=device, dtype=dtype)[None, :]
|
||||
|
||||
patch_ids = patch_ids.reshape(-1, 3)
|
||||
latent_image_ids.append(patch_ids)
|
||||
|
||||
return torch.cat(latent_image_ids, dim=0)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, sizes, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
start = 0
|
||||
unpacked_latents = []
|
||||
for i in range(len(sizes)):
|
||||
cur_size = sizes[i]
|
||||
height = cur_size[0][0] // vae_scale_factor
|
||||
width = sum([size[1] for size in cur_size]) // vae_scale_factor
|
||||
|
||||
end = start + (height * width) // 4
|
||||
|
||||
cur_latents = latents[:, start:end]
|
||||
cur_latents = cur_latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
cur_latents = cur_latents.permute(0, 3, 1, 4, 2, 5)
|
||||
cur_latents = cur_latents.reshape(batch_size, channels // (2 * 2), height, width)
|
||||
|
||||
unpacked_latents.append(cur_latents)
|
||||
|
||||
start = end
|
||||
|
||||
return unpacked_latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def _prepare_latents(self, image, mask, gen, vae_scale_factor, device, dtype):
|
||||
"""Helper function to prepare latents for a single batch."""
|
||||
# Concatenate images and masks along width dimension
|
||||
image = [torch.cat(img, dim=3).to(device=device, dtype=dtype) for img in image]
|
||||
mask = [torch.cat(m, dim=3).to(device=device, dtype=dtype) for m in mask]
|
||||
|
||||
# Generate latent image IDs
|
||||
latent_image_ids = self._prepare_latent_image_ids(image, vae_scale_factor, device, dtype)
|
||||
|
||||
# For initial encoding, use actual images
|
||||
image_latent = [self._encode_vae_image(img, gen) for img in image]
|
||||
masked_image_latent = [img.clone() for img in image_latent]
|
||||
|
||||
for i in range(len(image_latent)):
|
||||
# Rearrange latents and masks for patch processing
|
||||
num_channels_latents, height, width = image_latent[i].shape[1:]
|
||||
image_latent[i] = self._pack_latents(image_latent[i], 1, num_channels_latents, height, width)
|
||||
masked_image_latent[i] = self._pack_latents(masked_image_latent[i], 1, num_channels_latents, height, width)
|
||||
|
||||
# Rearrange masks for patch processing
|
||||
num_channels_latents, height, width = mask[i].shape[1:]
|
||||
mask[i] = mask[i].view(
|
||||
1,
|
||||
num_channels_latents,
|
||||
height // vae_scale_factor,
|
||||
vae_scale_factor,
|
||||
width // vae_scale_factor,
|
||||
vae_scale_factor,
|
||||
)
|
||||
mask[i] = mask[i].permute(0, 1, 3, 5, 2, 4)
|
||||
mask[i] = mask[i].reshape(
|
||||
1,
|
||||
num_channels_latents * (vae_scale_factor**2),
|
||||
height // vae_scale_factor,
|
||||
width // vae_scale_factor,
|
||||
)
|
||||
mask[i] = self._pack_latents(
|
||||
mask[i],
|
||||
1,
|
||||
num_channels_latents * (vae_scale_factor**2),
|
||||
height // vae_scale_factor,
|
||||
width // vae_scale_factor,
|
||||
)
|
||||
|
||||
# Concatenate along batch dimension
|
||||
image_latent = torch.cat(image_latent, dim=1)
|
||||
masked_image_latent = torch.cat(masked_image_latent, dim=1)
|
||||
mask = torch.cat(mask, dim=1)
|
||||
|
||||
return image_latent, masked_image_latent, mask, latent_image_ids
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
input_image,
|
||||
input_mask,
|
||||
timestep,
|
||||
batch_size,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
vae_scale_factor,
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
# Process each batch
|
||||
masked_image_latents = []
|
||||
image_latents = []
|
||||
masks = []
|
||||
latent_image_ids = []
|
||||
|
||||
for i in range(len(input_image)):
|
||||
_image_latent, _masked_image_latent, _mask, _latent_image_ids = self._prepare_latents(
|
||||
input_image[i],
|
||||
input_mask[i],
|
||||
generator if isinstance(generator, torch.Generator) else generator[i],
|
||||
vae_scale_factor,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
masked_image_latents.append(_masked_image_latent)
|
||||
image_latents.append(_image_latent)
|
||||
masks.append(_mask)
|
||||
latent_image_ids.append(_latent_image_ids)
|
||||
|
||||
# Concatenate all batches
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
masks = torch.cat(masks, dim=0)
|
||||
|
||||
# Handle batch size expansion
|
||||
if batch_size > masked_image_latents.shape[0]:
|
||||
if batch_size % masked_image_latents.shape[0] == 0:
|
||||
# Expand batches by repeating
|
||||
additional_image_per_prompt = batch_size // masked_image_latents.shape[0]
|
||||
masked_image_latents = torch.cat([masked_image_latents] * additional_image_per_prompt, dim=0)
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
masks = torch.cat([masks] * additional_image_per_prompt, dim=0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot expand batch size from {masked_image_latents.shape[0]} to {batch_size}. "
|
||||
"Batch sizes must be multiples of each other."
|
||||
)
|
||||
|
||||
# Add noise to latents
|
||||
noises = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self.scheduler.scale_noise(image_latents, timestep, noises).to(dtype=dtype)
|
||||
|
||||
# Combine masked latents with masks
|
||||
masked_image_latents = torch.cat((masked_image_latents, masks), dim=-1).to(dtype=dtype)
|
||||
|
||||
return latents, masked_image_latents, latent_image_ids[0]
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
task_prompt: Union[str, List[str]] = None,
|
||||
content_prompt: Union[str, List[str]] = None,
|
||||
image: Optional[torch.FloatTensor] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 30.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the VisualCloze pipeline for generation.
|
||||
|
||||
Args:
|
||||
task_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to define the task intention.
|
||||
content_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to define the content or caption of the target image to be generated.
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 30.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
||||
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
||||
images.
|
||||
"""
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
image,
|
||||
task_prompt,
|
||||
content_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
processor_output = self.image_processor.preprocess(
|
||||
task_prompt, content_prompt, image, vae_scale_factor=self.vae_scale_factor
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], str):
|
||||
batch_size = 1
|
||||
elif processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], list):
|
||||
batch_size = len(processor_output["task_prompt"])
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare prompt embeddings
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
|
||||
layout_prompt=processor_output["layout_prompt"],
|
||||
task_prompt=processor_output["task_prompt"],
|
||||
content_prompt=processor_output["content_prompt"],
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
# Calculate sequence length and shift factor
|
||||
image_seq_len = sum(
|
||||
(size[0] // self.vae_scale_factor // 2) * (size[1] // self.vae_scale_factor // 2)
|
||||
for sample in processor_output["image_size"][0]
|
||||
for size in sample
|
||||
)
|
||||
|
||||
# Calculate noise schedule parameters
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
|
||||
# Get timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
latents, masked_image_latents, latent_image_ids = self.prepare_latents(
|
||||
processor_output["init_image"],
|
||||
processor_output["mask"],
|
||||
latent_timestep,
|
||||
batch_size * num_images_per_prompt,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
vae_scale_factor=self.vae_scale_factor,
|
||||
)
|
||||
|
||||
# Calculate warmup steps
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# Prepare guidance
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
latent_model_input = torch.cat((latents, masked_image_latents), dim=2)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# Some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# Call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
# XLA optimization
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# 7. Post-process the image
|
||||
# Crop the target image
|
||||
# Since the generated image is a concatenation of the conditional and target regions,
|
||||
# we need to extract only the target regions based on their positions
|
||||
image = []
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
for b in range(len(latents)):
|
||||
cur_image_size = processor_output["image_size"][b % batch_size]
|
||||
cur_target_position = processor_output["target_position"][b % batch_size]
|
||||
cur_latent = self._unpack_latents(latents[b].unsqueeze(0), cur_image_size, self.vae_scale_factor)[-1]
|
||||
cur_latent = (cur_latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
cur_image = self.vae.decode(cur_latent, return_dict=False)[0]
|
||||
cur_image = self.image_processor.postprocess(cur_image, output_type=output_type)[0]
|
||||
|
||||
start = 0
|
||||
cropped = []
|
||||
for i, size in enumerate(cur_image_size[-1]):
|
||||
if cur_target_position[i]:
|
||||
if output_type == "pil":
|
||||
cropped.append(cur_image.crop((start, 0, start + size[1], size[0])))
|
||||
else:
|
||||
cropped.append(cur_image[0 : size[0], start : start + size[1]])
|
||||
start += size[1]
|
||||
image.append(cropped)
|
||||
if output_type != "pil":
|
||||
image = np.concatenate([arr[None] for sub_image in image for arr in sub_image], axis=0)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
@@ -0,0 +1,251 @@
|
||||
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
|
||||
|
||||
class VisualClozeProcessor(VaeImageProcessor):
|
||||
"""
|
||||
Image processor for the VisualCloze pipeline.
|
||||
|
||||
This processor handles the preprocessing of images for visual cloze tasks, including resizing, normalization, and
|
||||
mask generation.
|
||||
|
||||
Args:
|
||||
resolution (int, optional):
|
||||
Target resolution for processing images. Each image will be resized to this resolution before being
|
||||
concatenated to avoid the out-of-memory error. Defaults to 384.
|
||||
*args: Additional arguments passed to [~image_processor.VaeImageProcessor]
|
||||
**kwargs: Additional keyword arguments passed to [~image_processor.VaeImageProcessor]
|
||||
"""
|
||||
|
||||
def __init__(self, *args, resolution: int = 384, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.resolution = resolution
|
||||
|
||||
def preprocess_image(
|
||||
self, input_images: List[List[Optional[Image.Image]]], vae_scale_factor: int
|
||||
) -> Tuple[List[List[torch.Tensor]], List[List[List[int]]], List[int]]:
|
||||
"""
|
||||
Preprocesses input images for the VisualCloze pipeline.
|
||||
|
||||
This function handles the preprocessing of input images by:
|
||||
1. Resizing and cropping images to maintain consistent dimensions
|
||||
2. Converting images to the Tensor format for the VAE
|
||||
3. Normalizing pixel values
|
||||
4. Tracking image sizes and positions of target images
|
||||
|
||||
Args:
|
||||
input_images (List[List[Optional[Image.Image]]]):
|
||||
A nested list of PIL Images where:
|
||||
- Outer list represents different samples, including in-context examples and the query
|
||||
- Inner list contains images for the task
|
||||
- In the last row, condition images are provided and the target images are placed as None
|
||||
vae_scale_factor (int):
|
||||
The scale factor used by the VAE for resizing images
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- List[List[torch.Tensor]]: Preprocessed images in tensor format
|
||||
- List[List[List[int]]]: Dimensions of each processed image [height, width]
|
||||
- List[int]: Target positions indicating which images are to be generated
|
||||
"""
|
||||
n_samples, n_task_images = len(input_images), len(input_images[0])
|
||||
divisible = 2 * vae_scale_factor
|
||||
|
||||
processed_images: List[List[Image.Image]] = [[] for _ in range(n_samples)]
|
||||
resize_size: List[Optional[Tuple[int, int]]] = [None for _ in range(n_samples)]
|
||||
target_position: List[int] = []
|
||||
|
||||
# Process each sample
|
||||
for i in range(n_samples):
|
||||
# Determine size from first non-None image
|
||||
for j in range(n_task_images):
|
||||
if input_images[i][j] is not None:
|
||||
aspect_ratio = input_images[i][j].width / input_images[i][j].height
|
||||
target_area = self.resolution * self.resolution
|
||||
new_h = int((target_area / aspect_ratio) ** 0.5)
|
||||
new_w = int(new_h * aspect_ratio)
|
||||
|
||||
new_w = max(new_w // divisible, 1) * divisible
|
||||
new_h = max(new_h // divisible, 1) * divisible
|
||||
resize_size[i] = (new_w, new_h)
|
||||
break
|
||||
|
||||
# Process all images in the sample
|
||||
for j in range(n_task_images):
|
||||
if input_images[i][j] is not None:
|
||||
target = self._resize_and_crop(input_images[i][j], resize_size[i][0], resize_size[i][1])
|
||||
processed_images[i].append(target)
|
||||
if i == n_samples - 1:
|
||||
target_position.append(0)
|
||||
else:
|
||||
blank = Image.new("RGB", resize_size[i] or (self.resolution, self.resolution), (0, 0, 0))
|
||||
processed_images[i].append(blank)
|
||||
if i == n_samples - 1:
|
||||
target_position.append(1)
|
||||
|
||||
# Ensure consistent width for multiple target images when there are multiple target images
|
||||
if len(target_position) > 1 and sum(target_position) > 1:
|
||||
new_w = resize_size[n_samples - 1][0] or 384
|
||||
for i in range(len(processed_images)):
|
||||
for j in range(len(processed_images[i])):
|
||||
if processed_images[i][j] is not None:
|
||||
new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width))
|
||||
new_w = int(new_w / 16) * 16
|
||||
new_h = int(new_h / 16) * 16
|
||||
processed_images[i][j] = self.height(processed_images[i][j], new_h, new_w)
|
||||
|
||||
# Convert to tensors and normalize
|
||||
image_sizes = []
|
||||
for i in range(len(processed_images)):
|
||||
image_sizes.append([[img.height, img.width] for img in processed_images[i]])
|
||||
for j, image in enumerate(processed_images[i]):
|
||||
image = self.pil_to_numpy(image)
|
||||
image = self.numpy_to_pt(image)
|
||||
image = self.normalize(image)
|
||||
processed_images[i][j] = image
|
||||
|
||||
return processed_images, image_sizes, target_position
|
||||
|
||||
def preprocess_mask(
|
||||
self, input_images: List[List[Image.Image]], target_position: List[int]
|
||||
) -> List[List[torch.Tensor]]:
|
||||
"""
|
||||
Generate masks for the VisualCloze pipeline.
|
||||
|
||||
Args:
|
||||
input_images (List[List[Image.Image]]):
|
||||
Processed images from preprocess_image
|
||||
target_position (List[int]):
|
||||
Binary list marking the positions of target images (1 for target, 0 for condition)
|
||||
|
||||
Returns:
|
||||
List[List[torch.Tensor]]:
|
||||
A nested list of mask tensors (1 for target positions, 0 for condition images)
|
||||
"""
|
||||
mask = []
|
||||
for i, row in enumerate(input_images):
|
||||
if i == len(input_images) - 1: # Query row
|
||||
row_masks = [
|
||||
torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=m) for m in target_position
|
||||
]
|
||||
else: # In-context examples
|
||||
row_masks = [
|
||||
torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=0) for _ in target_position
|
||||
]
|
||||
mask.append(row_masks)
|
||||
return mask
|
||||
|
||||
def preprocess_image_upsampling(
|
||||
self,
|
||||
input_images: List[List[Image.Image]],
|
||||
height: int,
|
||||
width: int,
|
||||
) -> Tuple[List[List[Image.Image]], List[List[List[int]]]]:
|
||||
"""Process images for the upsampling stage in the VisualCloze pipeline.
|
||||
|
||||
Args:
|
||||
input_images: Input image to process
|
||||
height: Target height
|
||||
width: Target width
|
||||
|
||||
Returns:
|
||||
Tuple of processed image and its size
|
||||
"""
|
||||
image = self.resize(input_images[0][0], height, width)
|
||||
image = self.pil_to_numpy(image) # to np
|
||||
image = self.numpy_to_pt(image) # to pt
|
||||
image = self.normalize(image)
|
||||
|
||||
input_images[0][0] = image
|
||||
image_sizes = [[[height, width]]]
|
||||
return input_images, image_sizes
|
||||
|
||||
def preprocess_mask_upsampling(self, input_images: List[List[Image.Image]]) -> List[List[torch.Tensor]]:
|
||||
return [[torch.ones((1, 1, input_images[0][0].shape[2], input_images[0][0].shape[3]))]]
|
||||
|
||||
def get_layout_prompt(self, size: Tuple[int, int]) -> str:
|
||||
layout_instruction = (
|
||||
f"A grid layout with {size[0]} rows and {size[1]} columns, displaying {size[0] * size[1]} images arranged side by side.",
|
||||
)
|
||||
return layout_instruction
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
task_prompt: Union[str, List[str]],
|
||||
content_prompt: Union[str, List[str]],
|
||||
input_images: Optional[List[List[List[Optional[str]]]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
upsampling: bool = False,
|
||||
vae_scale_factor: int = 16,
|
||||
) -> Dict:
|
||||
"""Process visual cloze inputs.
|
||||
|
||||
Args:
|
||||
task_prompt: Task description(s)
|
||||
content_prompt: Content description(s)
|
||||
input_images: List of images or None for the target images
|
||||
height: Optional target height for upsampling stage
|
||||
width: Optional target width for upsampling stage
|
||||
upsampling: Whether this is in the upsampling processing stage
|
||||
|
||||
Returns:
|
||||
Dictionary containing processed images, masks, prompts and metadata
|
||||
"""
|
||||
if isinstance(task_prompt, str):
|
||||
task_prompt = [task_prompt]
|
||||
content_prompt = [content_prompt]
|
||||
input_images = [input_images]
|
||||
|
||||
output = {
|
||||
"init_image": [],
|
||||
"mask": [],
|
||||
"task_prompt": task_prompt if not upsampling else [None for _ in range(len(task_prompt))],
|
||||
"content_prompt": content_prompt,
|
||||
"layout_prompt": [],
|
||||
"target_position": [],
|
||||
"image_size": [],
|
||||
}
|
||||
for i in range(len(task_prompt)):
|
||||
if upsampling:
|
||||
layout_prompt = None
|
||||
else:
|
||||
layout_prompt = self.get_layout_prompt((len(input_images[i]), len(input_images[i][0])))
|
||||
|
||||
if upsampling:
|
||||
cur_processed_images, cur_image_size = self.preprocess_image_upsampling(
|
||||
input_images[i], height=height, width=width
|
||||
)
|
||||
cur_mask = self.preprocess_mask_upsampling(cur_processed_images)
|
||||
else:
|
||||
cur_processed_images, cur_image_size, cur_target_position = self.preprocess_image(
|
||||
input_images[i], vae_scale_factor=vae_scale_factor
|
||||
)
|
||||
cur_mask = self.preprocess_mask(cur_processed_images, cur_target_position)
|
||||
|
||||
output["target_position"].append(cur_target_position)
|
||||
|
||||
output["image_size"].append(cur_image_size)
|
||||
output["init_image"].append(cur_processed_images)
|
||||
output["mask"].append(cur_mask)
|
||||
output["layout_prompt"].append(layout_prompt)
|
||||
|
||||
return output
|
||||
@@ -12,5 +12,183 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ..utils import is_transformers_available, logging
|
||||
from .auto import DiffusersAutoQuantizer
|
||||
from .base import DiffusersQuantizer
|
||||
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
|
||||
|
||||
|
||||
try:
|
||||
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
|
||||
except ImportError:
|
||||
|
||||
class TransformersQuantConfigMixin:
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class PipelineQuantizationConfig:
|
||||
"""
|
||||
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
Args:
|
||||
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
|
||||
is available to both `diffusers` and `transformers`.
|
||||
quant_kwargs (`dict`): Params to initialize the quantization backend class.
|
||||
components_to_quantize (`list`): Components of a pipeline to be quantized.
|
||||
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
|
||||
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
|
||||
and `components_to_quantize`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_backend: str = None,
|
||||
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
|
||||
components_to_quantize: Optional[List[str]] = None,
|
||||
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
|
||||
):
|
||||
self.quant_backend = quant_backend
|
||||
# Initialize kwargs to be {} to set to the defaults.
|
||||
self.quant_kwargs = quant_kwargs or {}
|
||||
self.components_to_quantize = components_to_quantize
|
||||
self.quant_mapping = quant_mapping
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
quant_mapping = self.quant_mapping
|
||||
self.is_granular = True if quant_mapping is not None else False
|
||||
|
||||
self._validate_init_args()
|
||||
|
||||
def _validate_init_args(self):
|
||||
if self.quant_backend and self.quant_mapping:
|
||||
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
|
||||
|
||||
if not self.quant_mapping and not self.quant_backend:
|
||||
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
|
||||
|
||||
if not self.quant_kwargs and not self.quant_mapping:
|
||||
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
|
||||
|
||||
if self.quant_backend is not None:
|
||||
self._validate_init_kwargs_in_backends()
|
||||
|
||||
if self.quant_mapping is not None:
|
||||
self._validate_quant_mapping_args()
|
||||
|
||||
def _validate_init_kwargs_in_backends(self):
|
||||
quant_backend = self.quant_backend
|
||||
|
||||
self._check_backend_availability(quant_backend)
|
||||
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
if quant_config_mapping_transformers is not None:
|
||||
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
|
||||
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
|
||||
else:
|
||||
init_kwargs_transformers = None
|
||||
|
||||
init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
|
||||
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
|
||||
|
||||
if init_kwargs_transformers != init_kwargs_diffusers:
|
||||
raise ValueError(
|
||||
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
|
||||
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
|
||||
"this mapping would look like."
|
||||
)
|
||||
|
||||
def _validate_quant_mapping_args(self):
|
||||
quant_mapping = self.quant_mapping
|
||||
transformers_map, diffusers_map = self._get_quant_config_list()
|
||||
|
||||
available_transformers = list(transformers_map.values()) if transformers_map else None
|
||||
available_diffusers = list(diffusers_map.values())
|
||||
|
||||
for module_name, config in quant_mapping.items():
|
||||
if any(isinstance(config, cfg) for cfg in available_diffusers):
|
||||
continue
|
||||
|
||||
if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
|
||||
continue
|
||||
|
||||
if available_transformers:
|
||||
raise ValueError(
|
||||
f"Provided config for module_name={module_name} could not be found. "
|
||||
f"Available diffusers configs: {available_diffusers}; "
|
||||
f"Available transformers configs: {available_transformers}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided config for module_name={module_name} could not be found. "
|
||||
f"Available diffusers configs: {available_diffusers}."
|
||||
)
|
||||
|
||||
def _check_backend_availability(self, quant_backend: str):
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
available_backends_transformers = (
|
||||
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
|
||||
)
|
||||
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
|
||||
|
||||
if (
|
||||
available_backends_transformers and quant_backend not in available_backends_transformers
|
||||
) or quant_backend not in quant_config_mapping_diffusers:
|
||||
error_message = f"Provided quant_backend={quant_backend} was not found."
|
||||
if available_backends_transformers:
|
||||
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
|
||||
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
|
||||
raise ValueError(error_message)
|
||||
|
||||
def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
|
||||
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
|
||||
|
||||
quant_mapping = self.quant_mapping
|
||||
components_to_quantize = self.components_to_quantize
|
||||
|
||||
# Granular case
|
||||
if self.is_granular and module_name in quant_mapping:
|
||||
logger.debug(f"Initializing quantization config class for {module_name}.")
|
||||
config = quant_mapping[module_name]
|
||||
return config
|
||||
|
||||
# Global config case
|
||||
else:
|
||||
should_quantize = False
|
||||
# Only quantize the modules requested for.
|
||||
if components_to_quantize and module_name in components_to_quantize:
|
||||
should_quantize = True
|
||||
# No specification for `components_to_quantize` means all modules should be quantized.
|
||||
elif not self.is_granular and not components_to_quantize:
|
||||
should_quantize = True
|
||||
|
||||
if should_quantize:
|
||||
logger.debug(f"Initializing quantization config class for {module_name}.")
|
||||
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
|
||||
quant_config_cls = mapping_to_use[self.quant_backend]
|
||||
quant_kwargs = self.quant_kwargs
|
||||
return quant_config_cls(**quant_kwargs)
|
||||
|
||||
# Fallback: no applicable configuration found.
|
||||
return None
|
||||
|
||||
def _get_quant_config_list(self):
|
||||
if is_transformers_available():
|
||||
from transformers.quantizers.auto import (
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
|
||||
)
|
||||
else:
|
||||
quant_config_mapping_transformers = None
|
||||
|
||||
from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
|
||||
|
||||
return quant_config_mapping_transformers, quant_config_mapping_diffusers
|
||||
|
||||
@@ -408,6 +408,18 @@ class GGUFParameter(torch.nn.Parameter):
|
||||
def as_tensor(self):
|
||||
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
|
||||
|
||||
@staticmethod
|
||||
def _extract_quant_type(args):
|
||||
# When converting from original format checkpoints we often use splits, cats etc on tensors
|
||||
# this method ensures that the returned tensor type from those operations remains GGUFParameter
|
||||
# so that we preserve quant_type information
|
||||
for arg in args:
|
||||
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
|
||||
return arg[0].quant_type
|
||||
if isinstance(arg, GGUFParameter):
|
||||
return arg.quant_type
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
@@ -415,22 +427,13 @@ class GGUFParameter(torch.nn.Parameter):
|
||||
|
||||
result = super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
# When converting from original format checkpoints we often use splits, cats etc on tensors
|
||||
# this method ensures that the returned tensor type from those operations remains GGUFParameter
|
||||
# so that we preserve quant_type information
|
||||
quant_type = None
|
||||
for arg in args:
|
||||
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
|
||||
quant_type = arg[0].quant_type
|
||||
break
|
||||
if isinstance(arg, GGUFParameter):
|
||||
quant_type = arg.quant_type
|
||||
break
|
||||
if isinstance(result, torch.Tensor):
|
||||
quant_type = cls._extract_quant_type(args)
|
||||
return cls(result, quant_type=quant_type)
|
||||
# Handle tuples and lists
|
||||
elif isinstance(result, (tuple, list)):
|
||||
elif type(result) in (list, tuple):
|
||||
# Preserve the original type (tuple or list)
|
||||
quant_type = cls._extract_quant_type(args)
|
||||
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
|
||||
return type(result)(wrapped)
|
||||
else:
|
||||
|
||||
@@ -75,7 +75,7 @@ class QuantizationConfigMixin:
|
||||
Args:
|
||||
config_dict (`Dict[str, Any]`):
|
||||
Dictionary that will be used to instantiate the configuration object.
|
||||
return_unused_kwargs (`bool`,*optional*, defaults to `False`):
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
|
||||
`PreTrainedModel`.
|
||||
kwargs (`Dict[str, Any]`):
|
||||
|
||||
@@ -144,7 +144,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
@@ -568,5 +568,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -176,7 +176,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
@@ -703,5 +703,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -103,11 +103,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
|
||||
sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
|
||||
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
sigmas = torch.arange(num_train_timesteps + 1, dtype=sigmas_dtype) / num_train_timesteps
|
||||
if sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(sigmas)
|
||||
elif sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(sigmas)
|
||||
sigmas = sigmas.to(torch.float32)
|
||||
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
@@ -159,7 +161,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = begin_index
|
||||
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
@@ -230,18 +232,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
if sigmas is None:
|
||||
sigmas = torch.linspace(0, 1, self.num_inference_steps)
|
||||
sigmas = torch.linspace(0, 1, self.num_inference_steps, dtype=sigmas_dtype)
|
||||
elif isinstance(sigmas, float):
|
||||
sigmas = torch.tensor(sigmas, dtype=torch.float32)
|
||||
sigmas = torch.tensor(sigmas, dtype=sigmas_dtype)
|
||||
else:
|
||||
sigmas = sigmas
|
||||
sigmas = sigmas.to(sigmas_dtype)
|
||||
if self.config.sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(sigmas)
|
||||
elif self.config.sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(sigmas)
|
||||
|
||||
sigmas = sigmas.to(dtype=torch.float32, device=device)
|
||||
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
@@ -315,6 +318,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
pred_original_sample: Optional[torch.Tensor] = None,
|
||||
) -> Union[EDMEulerSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
@@ -378,7 +382,8 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
|
||||
if pred_original_sample is None:
|
||||
pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
@@ -435,5 +440,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -62,9 +62,11 @@ from .import_utils import (
|
||||
get_objects_from_module,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_better_profanity_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bitsandbytes_version,
|
||||
is_bs4_available,
|
||||
is_cosmos_guardrail_available,
|
||||
is_flax_available,
|
||||
is_ftfy_available,
|
||||
is_gguf_available,
|
||||
@@ -78,6 +80,7 @@ from .import_utils import (
|
||||
is_k_diffusion_version,
|
||||
is_librosa_available,
|
||||
is_matplotlib_available,
|
||||
is_nltk_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
@@ -85,6 +88,7 @@ from .import_utils import (
|
||||
is_optimum_quanto_version,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_pytorch_retinaface_available,
|
||||
is_safetensors_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
|
||||
@@ -160,6 +160,21 @@ class AutoencoderKLCogVideoX(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLCosmos(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -430,6 +445,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CosmosTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class DiTTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -565,6 +595,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -392,6 +392,51 @@ class CogView4Pipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ConsisIDPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CosmosTextToWorldPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CosmosVideoToWorldPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class CycleDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -692,6 +737,21 @@ class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HunyuanVideoFramepackPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1247,6 +1307,21 @@ class LTXImageToVideoPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXLatentUpsamplePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -2732,6 +2807,36 @@ class VideoToVideoSDPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class VisualClozeGenerationPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class VisualClozePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class VQDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -215,6 +215,10 @@ _gguf_available, _gguf_version = _is_package_available("gguf")
|
||||
_torchao_available, _torchao_version = _is_package_available("torchao")
|
||||
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
||||
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
|
||||
_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
|
||||
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
|
||||
_nltk_available, _nltk_version = _is_package_available("nltk")
|
||||
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -353,6 +357,22 @@ def is_timm_available():
|
||||
return _timm_available
|
||||
|
||||
|
||||
def is_pytorch_retinaface_available():
|
||||
return _pytorch_retinaface_available
|
||||
|
||||
|
||||
def is_better_profanity_available():
|
||||
return _better_profanity_available
|
||||
|
||||
|
||||
def is_nltk_available():
|
||||
return _nltk_available
|
||||
|
||||
|
||||
def is_cosmos_guardrail_available():
|
||||
return _cosmos_guardrail_available
|
||||
|
||||
|
||||
def is_hpu_available():
|
||||
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
|
||||
|
||||
@@ -505,6 +525,22 @@ QUANTO_IMPORT_ERROR = """
|
||||
install optimum-quanto`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
PYTORCH_RETINAFACE_IMPORT_ERROR = """
|
||||
{0} requires the pytorch_retinaface library but it was not found in your environment. You can install it with pip: `pip install pytorch_retinaface`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
BETTER_PROFANITY_IMPORT_ERROR = """
|
||||
{0} requires the better_profanity library but it was not found in your environment. You can install it with pip: `pip install better_profanity`
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
NLTK_IMPORT_ERROR = """
|
||||
{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
|
||||
"""
|
||||
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
||||
@@ -533,6 +569,9 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
|
||||
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
|
||||
("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
|
||||
("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
|
||||
("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
|
||||
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from .import_utils import (
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_optimum_quanto_available,
|
||||
is_peft_available,
|
||||
is_timm_available,
|
||||
is_torch_available,
|
||||
@@ -486,6 +487,13 @@ def require_bitsandbytes(test_case):
|
||||
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
|
||||
|
||||
|
||||
def require_quanto(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
|
||||
|
||||
|
||||
def require_accelerate(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
|
||||
|
||||
@@ -31,13 +31,14 @@ from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, Flux
|
||||
from diffusers.utils import load_image, logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
floats_tensor,
|
||||
is_peft_available,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -809,10 +810,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_accelerator
|
||||
class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
"""internal note: The integration slices were obtained on audace.
|
||||
|
||||
@@ -827,7 +828,7 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
super().setUp()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
|
||||
@@ -836,13 +837,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
|
||||
del self.pipeline
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_flux_the_last_ben(self):
|
||||
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
# Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI
|
||||
# Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI
|
||||
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
|
||||
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
@@ -956,10 +957,10 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_accelerator
|
||||
class FluxControlLoRAIntegrationTests(unittest.TestCase):
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
@@ -969,17 +970,17 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
|
||||
super().setUp()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
self.pipeline = FluxControlPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
).to(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
|
||||
def test_lora(self, lora_ckpt_id):
|
||||
|
||||
@@ -28,13 +28,16 @@ from diffusers import (
|
||||
HunyuanVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
Expectations,
|
||||
backend_empty_cache,
|
||||
floats_tensor,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@@ -192,10 +195,10 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_accelerator
|
||||
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
"""internal note: The integration slices were obtained on DGX.
|
||||
|
||||
@@ -210,7 +213,7 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
super().setUp()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
@@ -218,13 +221,13 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
self.pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
).to(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_original_format_cseti(self):
|
||||
self.pipeline.load_lora_weights(
|
||||
@@ -249,8 +252,13 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
out_slice = np.concatenate((out[:8], out[-8:]))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("cuda", 7): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
|
||||
|
||||
@@ -93,12 +93,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Keeping this test here makes sense because it doesn't look any integration
|
||||
# (value assertions on logits).
|
||||
|
||||
@@ -34,7 +34,7 @@ from diffusers.utils.testing_utils import (
|
||||
is_flaky,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_big_accelerator,
|
||||
require_peft_backend,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
@@ -138,8 +138,8 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_accelerator
|
||||
class SD3LoraIntegrationTests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3Img2ImgPipeline
|
||||
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
|
||||
@@ -37,12 +37,13 @@ from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
is_flaky,
|
||||
load_image,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -105,12 +106,12 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@is_flaky
|
||||
def test_multiple_wrong_adapter_name_raises_error(self):
|
||||
@@ -119,18 +120,18 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_sdxl_1_0_lora(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLCosmos
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLCosmos
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_cosmos_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"encoder_block_out_channels": (8, 8, 8, 8),
|
||||
"decode_block_out_channels": (8, 8, 8, 8),
|
||||
"attention_resolutions": (8,),
|
||||
"resolution": 64,
|
||||
"num_layers": 2,
|
||||
"patch_size": 4,
|
||||
"patch_type": "haar",
|
||||
"scaling_factor": 1.0,
|
||||
"spatial_compression_ratio": 4,
|
||||
"temporal_compression_ratio": 4,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
height = 32
|
||||
width = 32
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_cosmos_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CosmosEncoder3d",
|
||||
"CosmosDecoder3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Not sure why this test fails. Investigate later.")
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
@@ -62,7 +62,6 @@ from diffusers.utils.testing_utils import (
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
backend_synchronize,
|
||||
floats_tensor,
|
||||
get_python_version,
|
||||
is_torch_compile,
|
||||
numpy_cosine_similarity_distance,
|
||||
@@ -1581,6 +1580,34 @@ class ModelTesterMixin:
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
|
||||
|
||||
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
|
||||
@require_torch_accelerator
|
||||
@torch.no_grad()
|
||||
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
|
||||
torch.manual_seed(0)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
if not getattr(model, "_supports_group_offloading", True):
|
||||
return
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
storage_dtype, compute_dtype = torch.float16, torch.float32
|
||||
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
|
||||
)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
def test_auto_model(self, expected_max_diff=5e-5):
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
@@ -1754,7 +1781,7 @@ class TorchCompileTesterMixin:
|
||||
@require_peft_backend
|
||||
@require_peft_version_greater("0.14.0")
|
||||
@is_torch_compile
|
||||
class TestLoraHotSwappingForModel(unittest.TestCase):
|
||||
class LoraHotSwappingForModelTesterMixin:
|
||||
"""Test that hotswapping does not result in recompilation on the model directly.
|
||||
|
||||
We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
|
||||
@@ -1775,48 +1802,24 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_small_unet(self):
|
||||
# from diffusers UNet2DConditionModelTests
|
||||
torch.manual_seed(0)
|
||||
init_dict = {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 4,
|
||||
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
|
||||
"cross_attention_dim": 8,
|
||||
"attention_head_dim": 2,
|
||||
"out_channels": 4,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 1,
|
||||
"sample_size": 16,
|
||||
}
|
||||
model = UNet2DConditionModel(**init_dict)
|
||||
return model.to(torch_device)
|
||||
|
||||
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
|
||||
def get_lora_config(self, lora_rank, lora_alpha, target_modules):
|
||||
# from diffusers test_models_unet_2d_condition.py
|
||||
from peft import LoraConfig
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=target_modules,
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
return unet_lora_config
|
||||
return lora_config
|
||||
|
||||
def get_dummy_input(self):
|
||||
# from UNet2DConditionModelTests
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
def get_linear_module_name_other_than_attn(self, model):
|
||||
linear_names = [
|
||||
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
|
||||
]
|
||||
return linear_names[0]
|
||||
|
||||
def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
|
||||
"""
|
||||
@@ -1834,23 +1837,27 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
|
||||
fine.
|
||||
"""
|
||||
# create 2 adapters with different ranks and alphas
|
||||
dummy_input = self.get_dummy_input()
|
||||
torch.manual_seed(0)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
alpha0, alpha1 = rank0, rank1
|
||||
max_rank = max([rank0, rank1])
|
||||
if target_modules1 is None:
|
||||
target_modules1 = target_modules0[:]
|
||||
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
|
||||
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
|
||||
lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
|
||||
lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)
|
||||
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config0, adapter_name="adapter0")
|
||||
model.add_adapter(lora_config0, adapter_name="adapter0")
|
||||
with torch.inference_mode():
|
||||
output0_before = unet(**dummy_input)["sample"]
|
||||
torch.manual_seed(0)
|
||||
output0_before = model(**inputs_dict)["sample"]
|
||||
|
||||
unet.add_adapter(lora_config1, adapter_name="adapter1")
|
||||
unet.set_adapter("adapter1")
|
||||
model.add_adapter(lora_config1, adapter_name="adapter1")
|
||||
model.set_adapter("adapter1")
|
||||
with torch.inference_mode():
|
||||
output1_before = unet(**dummy_input)["sample"]
|
||||
torch.manual_seed(0)
|
||||
output1_before = model(**inputs_dict)["sample"]
|
||||
|
||||
# sanity checks:
|
||||
tol = 5e-3
|
||||
@@ -1860,40 +1867,43 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
# save the adapter checkpoints
|
||||
unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
|
||||
unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
|
||||
del unet
|
||||
model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
|
||||
model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
|
||||
del model
|
||||
|
||||
# load the first adapter
|
||||
unet = self.get_small_unet()
|
||||
torch.manual_seed(0)
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if do_compile or (rank0 != rank1):
|
||||
# no need to prepare if the model is not compiled or if the ranks are identical
|
||||
unet.enable_lora_hotswap(target_rank=max_rank)
|
||||
model.enable_lora_hotswap(target_rank=max_rank)
|
||||
|
||||
file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
|
||||
file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
|
||||
unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
|
||||
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
|
||||
|
||||
if do_compile:
|
||||
unet = torch.compile(unet, mode="reduce-overhead")
|
||||
model = torch.compile(model, mode="reduce-overhead")
|
||||
|
||||
with torch.inference_mode():
|
||||
output0_after = unet(**dummy_input)["sample"]
|
||||
output0_after = model(**inputs_dict)["sample"]
|
||||
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
|
||||
|
||||
# hotswap the 2nd adapter
|
||||
unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
|
||||
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
|
||||
|
||||
# we need to call forward to potentially trigger recompilation
|
||||
with torch.inference_mode():
|
||||
output1_after = unet(**dummy_input)["sample"]
|
||||
output1_after = model(**inputs_dict)["sample"]
|
||||
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
|
||||
|
||||
# check error when not passing valid adapter name
|
||||
name = "does-not-exist"
|
||||
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
|
||||
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_model(self, rank0, rank1):
|
||||
@@ -1910,6 +1920,9 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
|
||||
if "unet" not in self.model_class.__name__.lower():
|
||||
return
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["conv", "conv1", "conv2"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
@@ -1917,52 +1930,77 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
|
||||
if "unet" not in self.model_class.__name__.lower():
|
||||
return
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "conv"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
|
||||
def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
|
||||
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
|
||||
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
|
||||
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
|
||||
# block.
|
||||
target_modules = ["to_q"]
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
target_modules.append(self.get_linear_module_name_other_than_attn(model))
|
||||
del model
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config)
|
||||
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
unet.enable_lora_hotswap(target_rank=32)
|
||||
model.enable_lora_hotswap(target_rank=32)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
from diffusers.loaders.peft import logger
|
||||
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config)
|
||||
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with self.assertLogs(logger=logger, level="WARNING") as cm:
|
||||
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in log for log in cm.output)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
|
||||
# check possibility to ignore the error/warning
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config)
|
||||
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always") # Capture all warnings
|
||||
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
|
||||
unet = self.get_small_unet()
|
||||
unet.add_adapter(lora_config)
|
||||
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
|
||||
|
||||
def test_hotswap_second_adapter_targets_more_layers_raises(self):
|
||||
# check the error and log
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user