Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c73c00610e |
@@ -359,8 +359,6 @@ jobs:
|
||||
test_location: "bnb"
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Test installing diffusers and importing
|
||||
run: |
|
||||
pip install diffusers && pip uninstall diffusers -y
|
||||
pip install -i https://test.pypi.org/simple/ diffusers
|
||||
pip install -i https://testpypi.python.org/pypi diffusers
|
||||
python -c "from diffusers import __version__; print(__version__)"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
|
||||
|
||||
@@ -48,7 +48,7 @@
|
||||
- local: using-diffusers/inpaint
|
||||
title: Inpainting
|
||||
- local: using-diffusers/text-img2vid
|
||||
title: Video generation
|
||||
title: Text or image-to-video
|
||||
- local: using-diffusers/depth2img
|
||||
title: Depth-to-image
|
||||
title: Generative tasks
|
||||
@@ -429,7 +429,7 @@
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTXVideo
|
||||
title: LTX
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanVideo
|
||||
|
||||
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
|
||||
vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanVideo
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTXVideo
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HunyuanVideoTransformer3DModel
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import LTXVideoTransformer3DModel
|
||||
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## LTXVideoTransformer3DModel
|
||||
|
||||
@@ -22,7 +22,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import SanaTransformer2DModel
|
||||
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## SanaTransformer2DModel
|
||||
|
||||
@@ -19,55 +19,10 @@ The abstract from the paper is:
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`AllegroPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AllegroTransformer3DModel, AllegroPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"rhymes-ai/Allegro",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = AllegroTransformer3DModel.from_pretrained(
|
||||
"rhymes-ai/Allegro",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = AllegroPipeline.from_pretrained(
|
||||
"rhymes-ai/Allegro",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = (
|
||||
"A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
|
||||
"the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this "
|
||||
"location might be a popular spot for docking fishing boats."
|
||||
)
|
||||
video = pipeline(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]
|
||||
export_to_video(video, "harbor.mp4", fps=15)
|
||||
```
|
||||
|
||||
## AllegroPipeline
|
||||
|
||||
[[autodoc]] AllegroPipeline
|
||||
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# AuraFlow
|
||||
|
||||
AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3) and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.
|
||||
AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3.md) and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.
|
||||
|
||||
It was developed by the Fal team and more details about it can be found in [this blog post](https://blog.fal.ai/auraflow/).
|
||||
|
||||
@@ -22,46 +22,6 @@ AuraFlow can be quite expensive to run on consumer hardware devices. However, yo
|
||||
|
||||
</Tip>
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`AuraFlowPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AuraFlowTransformer2DModel, AuraFlowPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"fal/AuraFlow",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = AuraFlowTransformer2DModel.from_pretrained(
|
||||
"fal/AuraFlow",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = AuraFlowPipeline.from_pretrained(
|
||||
"fal/AuraFlow",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon"
|
||||
image = pipeline(prompt).images[0]
|
||||
image.save("auraflow.png")
|
||||
```
|
||||
|
||||
## AuraFlowPipeline
|
||||
|
||||
[[autodoc]] AuraFlowPipeline
|
||||
|
||||
@@ -23,7 +23,7 @@ The abstract from the paper is:
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -112,46 +112,13 @@ CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds o
|
||||
- With enabling cpu offloading and tiling, memory usage is `11 GB`
|
||||
- `pipe.vae.enable_slicing()`
|
||||
|
||||
## Quantization
|
||||
### Quantized inference
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
[torchao](https://github.com/pytorch/ao) and [optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be used to quantize the text encoder, transformer and VAE modules to lower the memory requirements. This makes it possible to run the model on a free-tier T4 Colab or lower VRAM GPUs!
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`CogVideoXPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, CogVideoXTransformer3DModel, CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"THUDM/CogVideoX-2b",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = CogVideoXTransformer3DModel.from_pretrained(
|
||||
"THUDM/CogVideoX-2b",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = CogVideoXPipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-2b",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
|
||||
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
export_to_video(video, "ship.mp4", fps=8)
|
||||
```
|
||||
It is also worth noting that torchao quantization is fully compatible with [torch.compile](/optimization/torch2.0#torchcompile), which allows for much faster inference speed. Additionally, models can be serialized and stored in a quantized datatype to save disk space with torchao. Find examples and benchmarks in the gists below.
|
||||
- [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897)
|
||||
- [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa)
|
||||
|
||||
## CogVideoXPipeline
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ The abstract from the paper is:
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
@@ -334,46 +334,6 @@ out = pipe(
|
||||
out.save("image.png")
|
||||
```
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`FluxPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="text_encoder_2",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon"
|
||||
image = pipeline(prompt, guidance_scale=3.5, height=768, width=1360, num_inference_steps=50).images[0]
|
||||
image.save("flux.png")
|
||||
```
|
||||
|
||||
## Single File Loading for the `FluxTransformer2DModel`
|
||||
|
||||
The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -29,40 +29,9 @@ Recommendations for inference:
|
||||
- Transformer should be in `torch.bfloat16`.
|
||||
- VAE should be in `torch.float16`.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
|
||||
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
|
||||
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`HunyuanVideoPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
"tencent/HunyuanVideo",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
"tencent/HunyuanVideo",
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "A cat walks on the grass, realistic style."
|
||||
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
|
||||
export_to_video(video, "cat.mp4", fps=15)
|
||||
```
|
||||
|
||||
## HunyuanVideoPipeline
|
||||
|
||||
[[autodoc]] HunyuanVideoPipeline
|
||||
|
||||
@@ -30,7 +30,7 @@ HunyuanDiT has the following components:
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ This pipeline was contributed by [maxin-cn](https://github.com/maxin-cn). The or
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -70,47 +70,6 @@ Without torch.compile(): Average inference time: 16.246 seconds.
|
||||
With torch.compile(): Average inference time: 14.573 seconds.
|
||||
```
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LattePipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LatteTransformer3DModel, LattePipeline
|
||||
from diffusers.utils import export_to_gif
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"maxin-cn/Latte-1",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = LatteTransformer3DModel.from_pretrained(
|
||||
"maxin-cn/Latte-1",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = LattePipeline.from_pretrained(
|
||||
"maxin-cn/Latte-1",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "A small cactus with a happy face in the Sahara desert."
|
||||
video = pipeline(prompt).frames[0]
|
||||
export_to_gif(video, "latte.gif")
|
||||
```
|
||||
|
||||
## LattePipeline
|
||||
|
||||
[[autodoc]] LattePipeline
|
||||
|
||||
@@ -12,34 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# LTX Video
|
||||
# LTX
|
||||
|
||||
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
Available models:
|
||||
|
||||
| Model name | Recommended dtype |
|
||||
|:-------------:|:-----------------:|
|
||||
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
|
||||
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
|
||||
|
||||
## Loading Single Files
|
||||
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
|
||||
|
||||
# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
|
||||
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
single_file_url, torch_dtype=torch.bfloat16
|
||||
@@ -109,77 +99,8 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)
|
||||
|
||||
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
|
||||
|
||||
<!-- TODO(aryan): Update this when official weights are supported -->
|
||||
|
||||
Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=768,
|
||||
height=512,
|
||||
num_frames=161,
|
||||
decode_timestep=0.03,
|
||||
decode_noise_scale=0.025,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
|
||||
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LTXPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LTXVideoTransformer3DModel, LTXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"Lightricks/LTX-Video",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = LTXVideoTransformer3DModel.from_pretrained(
|
||||
"Lightricks/LTX-Video",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = LTXPipeline.from_pretrained(
|
||||
"Lightricks/LTX-Video",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
|
||||
video = pipeline(prompt=prompt, num_frames=161, num_inference_steps=50).frames[0]
|
||||
export_to_video(video, "ship.mp4", fps=24)
|
||||
```
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
[[autodoc]] LTXPipeline
|
||||
|
||||
@@ -47,7 +47,7 @@ This pipeline was contributed by [PommesPeter](https://github.com/PommesPeter).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -82,46 +82,6 @@ pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fu
|
||||
image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures").images[0]
|
||||
```
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaText2ImgPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaText2ImgPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = Transformer2DModel.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon"
|
||||
image = pipeline(prompt).images[0]
|
||||
image.save("lumina.png")
|
||||
```
|
||||
|
||||
## LuminaText2ImgPipeline
|
||||
|
||||
[[autodoc]] LuminaText2ImgPipeline
|
||||
|
||||
@@ -15,59 +15,15 @@
|
||||
|
||||
# Mochi 1 Preview
|
||||
|
||||
> [!TIP]
|
||||
> Only a research preview of the model weights is available at the moment.
|
||||
|
||||
[Mochi 1](https://huggingface.co/genmo/mochi-1-preview) is a video generation model by Genmo with a strong focus on prompt adherence and motion quality. The model features a 10B parameter Asmmetric Diffusion Transformer (AsymmDiT) architecture, and uses non-square QKV and output projection layers to reduce inference memory requirements. A single T5-XXL model is used to encode prompts.
|
||||
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
|
||||
|
||||
*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.*
|
||||
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
<Tip>
|
||||
|
||||
## Quantization
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`MochiPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, MochiTransformer3DModel, MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"genmo/mochi-1-preview",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = MochiTransformer3DModel.from_pretrained(
|
||||
"genmo/mochi-1-preview",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = MochiPipeline.from_pretrained(
|
||||
"genmo/mochi-1-preview",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
video = pipeline(
|
||||
"Close-up of a cats eye, with the galaxy reflected in the cats eye. Ultra high resolution 4k.",
|
||||
num_inference_steps=28,
|
||||
guidance_scale=3.5
|
||||
).frames[0]
|
||||
export_to_video(video, "cat.mp4")
|
||||
```
|
||||
</Tip>
|
||||
|
||||
## Generating videos with Mochi-1 Preview
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ Some notes about this pipeline:
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ The abstract from the paper is:
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -32,9 +32,9 @@ Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
|
||||
@@ -50,46 +50,6 @@ Make sure to pass the `variant` argument for downloaded checkpoints to use lower
|
||||
|
||||
</Tip>
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModelForCausalLM
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModelForCausalLM.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SanaTransformer2DModel.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = SanaPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon"
|
||||
image = pipeline(prompt).images[0]
|
||||
image.save("sana.png")
|
||||
```
|
||||
|
||||
## SanaPipeline
|
||||
|
||||
[[autodoc]] SanaPipeline
|
||||
|
||||
@@ -35,57 +35,6 @@ During inference:
|
||||
* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
|
||||
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`StableAudioPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, StableAudioDiTModel, StableAudioPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"stabilityai/stable-audio-open-1.0",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = StableAudioDiTModel.from_pretrained(
|
||||
"stabilityai/stable-audio-open-1.0",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = StableAudioPipeline.from_pretrained(
|
||||
"stabilityai/stable-audio-open-1.0",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "The sound of a hammer hitting a wooden surface."
|
||||
negative_prompt = "Low quality."
|
||||
audio = pipeline(
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=200,
|
||||
audio_end_in_s=10.0,
|
||||
num_waveforms_per_prompt=3,
|
||||
generator=generator,
|
||||
).audios
|
||||
|
||||
output = audio[0].T.float().cpu().numpy()
|
||||
sf.write("hammer.wav", output, pipeline.vae.sampling_rate)
|
||||
```
|
||||
|
||||
|
||||
## StableAudioPipeline
|
||||
[[autodoc]] StableAudioPipeline
|
||||
|
||||
@@ -268,46 +268,6 @@ image.save("sd3_hello_world.png")
|
||||
|
||||
Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`StableDiffusion3Pipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SD3Transformer2DModel, StableDiffusion3Pipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = T5EncoderModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
subfolder="text_encoder_3",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SD3Transformer2DModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon"
|
||||
image = pipeline(prompt, num_inference_steps=28, guidance_scale=7.0).images[0]
|
||||
image.save("sd3.png")
|
||||
```
|
||||
|
||||
## Using Long Prompts with the T5 Text Encoder
|
||||
|
||||
By default, the T5 Text Encoder prompt uses a maximum sequence length of `256`. This can be adjusted by setting the `max_sequence_length` to accept fewer or more tokens. Keep in mind that longer sequences require additional resources and result in longer generation times, such as during batch inference.
|
||||
|
||||
@@ -79,8 +79,4 @@ Happy exploring, and thank you for being part of the Diffusers community!
|
||||
<td><a href="https://github.com/Netwrck/stable-diffusion-server"> Stable Diffusion Server </a></td>
|
||||
<td>A server configured for Inpainting/Generation/img2img with one stable diffusion model</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/suzukimain/auto_diffusers"> Model Search </a></td>
|
||||
<td>Search models on Civitai and Hugging Face</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
@@ -25,10 +25,9 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
|
||||
The example below only quantizes the weights to int8.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
model_id = "black-forest-labs/Flux.1-Dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
@@ -45,14 +44,8 @@ pipe = FluxPipeline.from_pretrained(
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Without quantization: ~31.447 GB
|
||||
# With quantization: ~20.40 GB
|
||||
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
).images[0]
|
||||
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
@@ -93,63 +86,6 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use
|
||||
|
||||
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
|
||||
|
||||
## Serializing and Deserializing quantized models
|
||||
|
||||
To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
|
||||
```
|
||||
|
||||
To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel
|
||||
|
||||
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
# Serialize the model
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=TorchAoConfig("uint4wo"),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
|
||||
# ...
|
||||
|
||||
# Load the model
|
||||
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
|
||||
with init_empty_weights():
|
||||
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
@@ -10,20 +10,31 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Video generation
|
||||
# Text or image-to-video
|
||||
|
||||
Video generation models include a temporal dimension to bring images, or frames, together to create a video. These models are trained on large-scale datasets of high-quality text-video pairs to learn how to combine the modalities to ensure the generated video is coherent and realistic.
|
||||
Driven by the success of text-to-image diffusion models, generative video models are able to generate short clips of video from a text prompt or an initial image. These models extend a pretrained diffusion model to generate videos by adding some type of temporal and/or spatial convolution layer to the architecture. A mixed dataset of images and videos are used to train the model which learns to output a series of video frames based on the text or image conditioning.
|
||||
|
||||
[Explore](https://huggingface.co/models?other=video-generation) some of the more popular open-source video generation models available from Diffusers below.
|
||||
This guide will show you how to generate videos, how to configure video model parameters, and how to control video generation.
|
||||
|
||||
<hfoptions id="popular-models">
|
||||
<hfoption id="CogVideoX">
|
||||
## Popular models
|
||||
|
||||
[CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) uses a 3D causal Variational Autoencoder (VAE) to compress videos along the spatial and temporal dimensions, and it includes a stack of expert transformer blocks with a 3D full attention mechanism to better capture visual, semantic, and motion information in the data.
|
||||
> [!TIP]
|
||||
> Discover other cool and trending video generation models on the Hub [here](https://huggingface.co/models?pipeline_tag=text-to-video&sort=trending)!
|
||||
|
||||
The CogVideoX family also includes models capable of generating videos from images and videos in addition to text. The image-to-video models are indicated by **I2V** in the checkpoint name, and they should be used with the [`CogVideoXImageToVideoPipeline`]. The regular checkpoints support video-to-video through the [`CogVideoXVideoToVideoPipeline`].
|
||||
[Stable Video Diffusions (SVD)](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid), [I2VGen-XL](https://huggingface.co/ali-vilab/i2vgen-xl/), [AnimateDiff](https://huggingface.co/guoyww/animatediff), and [ModelScopeT2V](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b) are popular models used for video diffusion. Each model is distinct. For example, AnimateDiff inserts a motion modeling module into a frozen text-to-image model to generate personalized animated images, whereas SVD is entirely pretrained from scratch with a three-stage training process to generate short high-quality videos.
|
||||
|
||||
The example below demonstrates how to generate a video from an image and text prompt with [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V).
|
||||
[CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) is another popular video generation model. The model is a multidimensional transformer that integrates text, time, and space. It employs full attention in the attention module and includes an expert block at the layer level to spatially align text and video.
|
||||
|
||||
### CogVideoX
|
||||
|
||||
[CogVideoX](../api/pipelines/cogvideox) uses a 3D Variational Autoencoder (VAE) to compress videos along the spatial and temporal dimensions.
|
||||
|
||||
Begin by loading the [`CogVideoXPipeline`] and passing an initial text or image to generate a video.
|
||||
<Tip>
|
||||
|
||||
CogVideoX is available for image-to-video and text-to-video. [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) uses the [`CogVideoXImageToVideoPipeline`] for image-to-video. [THUDM/CogVideoX-5b](https://huggingface.co/THUDM/CogVideoX-5b) and [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b) are available for text-to-video with the [`CogVideoXPipeline`].
|
||||
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -31,13 +42,12 @@ from diffusers import CogVideoXImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
|
||||
prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
|
||||
image = load_image(image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png")
|
||||
image = load_image(image="cogvideox_rocket.png")
|
||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-5b-I2V",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# reduce memory requirements
|
||||
|
||||
pipe.vae.enable_tiling()
|
||||
pipe.vae.enable_slicing()
|
||||
|
||||
@@ -50,6 +60,7 @@ video = pipe(
|
||||
guidance_scale=6,
|
||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
@@ -64,103 +75,12 @@ export_to_video(video, "output.mp4", fps=8)
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="HunyuanVideo">
|
||||
|
||||
### Stable Video Diffusion
|
||||
|
||||
> [!TIP]
|
||||
> HunyuanVideo is a 13B parameter model and requires a lot of memory. Refer to the HunyuanVideo [Quantization](../api/pipelines/hunyuan_video#quantization) guide to learn how to quantize the model. CogVideoX and LTX-Video are more lightweight options that can still generate high-quality videos.
|
||||
[SVD](../api/pipelines/svd) is based on the Stable Diffusion 2.1 model and it is trained on images, then low-resolution videos, and finally a smaller dataset of high-resolution videos. This model generates a short 2-4 second video from an initial image. You can learn more details about model, like micro-conditioning, in the [Stable Video Diffusion](../using-diffusers/svd) guide.
|
||||
|
||||
[HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo) features a dual-stream to single-stream diffusion transformer (DiT) for learning video and text tokens separately, and then subsequently concatenating the video and text tokens to combine their information. A single multimodal large language model (MLLM) serves as the text encoder, and videos are also spatio-temporally compressed with a 3D causal VAE.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
"tencent/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(
|
||||
"tencent/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# reduce memory requirements
|
||||
pipe.vae.enable_tiling()
|
||||
pipe.to("cuda")
|
||||
|
||||
video = pipe(
|
||||
prompt="A cat walks on the grass, realistic",
|
||||
height=320,
|
||||
width=512,
|
||||
num_frames=61,
|
||||
num_inference_steps=30,
|
||||
).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=15)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hunyuan-video-output.gif"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="LTX-Video">
|
||||
|
||||
[LTX-Video (LTXV)](https://huggingface.co/Lightricks/LTX-Video) is a diffusion transformer (DiT) with a focus on speed. It generates 768x512 resolution videos at 24 frames per second (fps), enabling near real-time generation of high-quality videos. LTXV is relatively lightweight compared to other modern video generation models, making it possible to run on consumer GPUs.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import LTXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
|
||||
|
||||
prompt = "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage."
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
width=704,
|
||||
height=480,
|
||||
num_frames=161,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/Lightricks/LTX-Video/resolve/main/media/ltx-video_example_00014.gif"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Mochi-1">
|
||||
|
||||
> [!TIP]
|
||||
> Mochi-1 is a 10B parameter model and requires a lot of memory. Refer to the Mochi [Quantization](../api/pipelines/mochi#quantization) guide to learn how to quantize the model. CogVideoX and LTX-Video are more lightweight options that can still generate high-quality videos.
|
||||
|
||||
[Mochi-1](https://huggingface.co/genmo/mochi-1-preview) introduces the Asymmetric Diffusion Transformer (AsymmDiT) and Asymmetric Variational Autoencoder (AsymmVAE) to reduces memory requirements. AsymmVAE causally compresses videos 128x to improve memory efficiency, and AsymmDiT jointly attends to the compressed video tokens and user text tokens. This model is noted for generating videos with high-quality motion dynamics and strong prompt adherence.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
|
||||
# reduce memory requirements
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
video = pipe(prompt, num_frames=84).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mochi-video-output.gif"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="StableVideoDiffusion">
|
||||
|
||||
[StableVideoDiffusion (SVD)](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) is based on the Stable Diffusion 2.1 model and it is trained on images, then low-resolution videos, and finally a smaller dataset of high-resolution videos. This model generates a short 2-4 second video from an initial image.
|
||||
Begin by loading the [`StableVideoDiffusionPipeline`] and passing an initial image to generate a video from.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -170,8 +90,6 @@ from diffusers.utils import load_image, export_to_video
|
||||
pipeline = StableVideoDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
|
||||
# reduce memory requirements
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
|
||||
@@ -193,12 +111,54 @@ export_to_video(frames, "generated.mp4", fps=7)
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AnimateDiff">
|
||||
### I2VGen-XL
|
||||
|
||||
[AnimateDiff](https://huggingface.co/guoyww/animatediff) is an adapter model that inserts a motion module into a pretrained diffusion model to animate an image. The adapter is trained on video clips to learn motion which is used to condition the generation process to create a video. It is faster and easier to only train the adapter and it can be loaded into most diffusion models, effectively turning them into “video models”.
|
||||
[I2VGen-XL](../api/pipelines/i2vgenxl) is a diffusion model that can generate higher resolution videos than SVD and it is also capable of accepting text prompts in addition to images. The model is trained with two hierarchical encoders (detail and global encoder) to better capture low and high-level details in images. These learned details are used to train a video diffusion model which refines the video resolution and details in the generated video.
|
||||
|
||||
Load a `MotionAdapter` and pass it to the [`AnimateDiffPipeline`].
|
||||
You can use I2VGen-XL by loading the [`I2VGenXLPipeline`], and passing a text and image prompt to generate a video.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import I2VGenXLPipeline
|
||||
from diffusers.utils import export_to_gif, load_image
|
||||
|
||||
pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
|
||||
image = load_image(image_url).convert("RGB")
|
||||
|
||||
prompt = "Papers were floating in the air on a table in the library"
|
||||
negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
|
||||
generator = torch.manual_seed(8888)
|
||||
|
||||
frames = pipeline(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
num_inference_steps=50,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=9.0,
|
||||
generator=generator
|
||||
).frames[0]
|
||||
export_to_gif(frames, "i2v.gif")
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/i2vgen-xl-example.gif"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated video</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### AnimateDiff
|
||||
|
||||
[AnimateDiff](../api/pipelines/animatediff) is an adapter model that inserts a motion module into a pretrained diffusion model to animate an image. The adapter is trained on video clips to learn motion which is used to condition the generation process to create a video. It is faster and easier to only train the adapter and it can be loaded into most diffusion models, effectively turning them into "video models".
|
||||
|
||||
Start by loading a [`MotionAdapter`].
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -206,6 +166,11 @@ from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
Then load a finetuned Stable Diffusion model with the [`AnimateDiffPipeline`].
|
||||
|
||||
```py
|
||||
pipeline = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16)
|
||||
scheduler = DDIMScheduler.from_pretrained(
|
||||
"emilianJR/epiCRealism",
|
||||
@@ -216,11 +181,13 @@ scheduler = DDIMScheduler.from_pretrained(
|
||||
steps_offset=1,
|
||||
)
|
||||
pipeline.scheduler = scheduler
|
||||
|
||||
# reduce memory requirements
|
||||
pipeline.enable_vae_slicing()
|
||||
pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
Create a prompt and generate the video.
|
||||
|
||||
```py
|
||||
output = pipeline(
|
||||
prompt="A space rocket with trails of smoke behind it launching into space from the desert, 4k, high resolution",
|
||||
negative_prompt="bad quality, worse quality, low resolution",
|
||||
@@ -234,11 +201,38 @@ export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff.gif"/>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff.gif"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
### ModelscopeT2V
|
||||
|
||||
[ModelscopeT2V](../api/pipelines/text_to_video) adds spatial and temporal convolutions and attention to a UNet, and it is trained on image-text and video-text datasets to enhance what it learns during training. The model takes a prompt, encodes it and creates text embeddings which are denoised by the UNet, and then decoded by a VQGAN into a video.
|
||||
|
||||
<Tip>
|
||||
|
||||
ModelScopeT2V generates watermarked videos due to the datasets it was trained on. To use a watermark-free model, try the [cerspense/zeroscope_v2_76w](https://huggingface.co/cerspense/zeroscope_v2_576w) model with the [`TextToVideoSDPipeline`] first, and then upscale it's output with the [cerspense/zeroscope_v2_XL](https://huggingface.co/cerspense/zeroscope_v2_XL) checkpoint using the [`VideoToVideoSDPipeline`].
|
||||
|
||||
</Tip>
|
||||
|
||||
Load a ModelScopeT2V checkpoint into the [`DiffusionPipeline`] along with a prompt to generate a video.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_vae_slicing()
|
||||
|
||||
prompt = "Confident teddy bear surfer rides the wave in the tropics"
|
||||
video_frames = pipeline(prompt).frames[0]
|
||||
export_to_video(video_frames, "modelscopet2v.mp4", fps=10)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/modelscopet2v.gif" />
|
||||
</div>
|
||||
|
||||
## Configure model parameters
|
||||
|
||||
@@ -554,9 +548,3 @@ If memory is not an issue and you want to optimize for speed, try wrapping the U
|
||||
+ pipeline.to("cuda")
|
||||
+ pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) to learn more about supported quantization backends (bitsandbytes, torchao, gguf) and selecting a quantization backend that supports your use case.
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRASANA(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
|
||||
transformer_layer_type = "transformer_blocks.0.attn1.to_k"
|
||||
|
||||
def test_dreambooth_lora_sana(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--resolution 32
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# `self.transformer_layer_type` should be in the state dict.
|
||||
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 166
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 16
|
||||
""".split()
|
||||
|
||||
resume_run_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -943,7 +943,7 @@ def main(args):
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
text_encoder = Gemma2Model.from_pretrained(
|
||||
@@ -964,6 +964,15 @@ def main(args):
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = SanaPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=None,
|
||||
transformer=None,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
@@ -984,15 +993,6 @@ def main(args):
|
||||
# because Gemma2 is particularly suited for bfloat16.
|
||||
text_encoder.to(dtype=torch.bfloat16)
|
||||
|
||||
# Initialize a text encoding pipeline and keep it to CPU for now.
|
||||
text_encoding_pipeline = SanaPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=None,
|
||||
transformer=None,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
transformer.enable_gradient_checkpointing()
|
||||
|
||||
@@ -1182,7 +1182,6 @@ def main(args):
|
||||
)
|
||||
if args.offload:
|
||||
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
|
||||
prompt_embeds = prompt_embeds.to(transformer.dtype)
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
@@ -1217,7 +1216,7 @@ def main(args):
|
||||
vae_config_scaling_factor = vae.config.scaling_factor
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
vae = vae.to(accelerator.device)
|
||||
vae = vae.to("cuda")
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
with torch.no_grad():
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
|
||||
@@ -29,7 +29,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1292,17 +1292,11 @@ def main(args):
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
model = unwrap_model(model)
|
||||
if args.upcast_before_saving:
|
||||
model = model.to(torch.float32)
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif args.train_text_encoder and isinstance(
|
||||
unwrap_model(model), type(unwrap_model(text_encoder_one))
|
||||
): # or text_encoder_two
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two
|
||||
# both text encoders are of the same class, so we check hidden size to distinguish between the two
|
||||
model = unwrap_model(model)
|
||||
hidden_size = model.config.hidden_size
|
||||
hidden_size = unwrap_model(model).config.hidden_size
|
||||
if hidden_size == 768:
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif hidden_size == 1280:
|
||||
@@ -1311,8 +1305,7 @@ def main(args):
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
if weights:
|
||||
weights.pop()
|
||||
weights.pop()
|
||||
|
||||
StableDiffusion3Pipeline.save_lora_weights(
|
||||
output_dir,
|
||||
@@ -1326,31 +1319,17 @@ def main(args):
|
||||
text_encoder_one_ = None
|
||||
text_encoder_two_ = None
|
||||
|
||||
if not accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
transformer_ = unwrap_model(model)
|
||||
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_ = unwrap_model(model)
|
||||
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_ = unwrap_model(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
else:
|
||||
transformer_ = SD3Transformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="transformer"
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one_ = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder"
|
||||
)
|
||||
text_encoder_two_ = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2"
|
||||
)
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_ = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -1850,7 +1829,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -795,7 +795,7 @@ def main(args):
|
||||
flux_transformer.x_embedder = new_linear
|
||||
|
||||
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
@@ -1166,11 +1166,6 @@ def main(args):
|
||||
flux_transformer.to(torch.float32)
|
||||
flux_transformer.save_pretrained(args.output_dir)
|
||||
|
||||
del flux_transformer
|
||||
del text_encoding_pipeline
|
||||
del vae
|
||||
free_memory()
|
||||
|
||||
# Run a final round of validation.
|
||||
image_logs = None
|
||||
if args.validation_prompt is not None:
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -830,7 +830,7 @@ def main(args):
|
||||
flux_transformer.x_embedder = new_linear
|
||||
|
||||
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
|
||||
|
||||
if args.train_norm_layers:
|
||||
for name, param in flux_transformer.named_parameters():
|
||||
@@ -923,28 +923,11 @@ def main(args):
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
else:
|
||||
transformer_ = FluxTransformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="transformer"
|
||||
).to(accelerator.device, weight_dtype)
|
||||
|
||||
# Handle input dimension doubling before adding adapter
|
||||
with torch.no_grad():
|
||||
initial_input_channels = transformer_.config.in_channels
|
||||
new_linear = torch.nn.Linear(
|
||||
transformer_.x_embedder.in_features * 2,
|
||||
transformer_.x_embedder.out_features,
|
||||
bias=transformer_.x_embedder.bias is not None,
|
||||
dtype=transformer_.dtype,
|
||||
device=transformer_.device,
|
||||
)
|
||||
new_linear.weight.zero_()
|
||||
new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
|
||||
if transformer_.x_embedder.bias is not None:
|
||||
new_linear.bias.copy_(transformer_.x_embedder.bias)
|
||||
transformer_.x_embedder = new_linear
|
||||
transformer_.register_to_config(in_channels=initial_input_channels * 2)
|
||||
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
|
||||
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
|
||||
@@ -1336,11 +1319,6 @@ def main(args):
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
)
|
||||
|
||||
del flux_transformer
|
||||
del text_encoding_pipeline
|
||||
del vae
|
||||
free_memory()
|
||||
|
||||
# Run a final round of validation.
|
||||
image_logs = None
|
||||
if args.validation_prompt is not None:
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
@@ -23,9 +21,7 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"vae": remove_keys_,
|
||||
}
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
# decoder
|
||||
@@ -58,31 +54,10 @@ VAE_KEYS_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
@@ -105,16 +80,13 @@ def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
PREFIX_KEY = ""
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -125,21 +97,16 @@ def convert_transformer(
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
transformer.load_state_dict(original_state_dict, strict=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
PREFIX_KEY = "vae."
|
||||
|
||||
def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTXVideo(**config)
|
||||
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -150,60 +117,10 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
vae.load_state_dict(original_state_dict, strict=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.0":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"timestep_conditioning": False,
|
||||
}
|
||||
elif version == "0.9.1":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -222,9 +139,6 @@ def get_args():
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -247,7 +161,6 @@ if __name__ == "__main__":
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
output_path = Path(args.output_path)
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
@@ -256,14 +169,13 @@ if __name__ == "__main__":
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
|
||||
if args.vae_ckpt_path is not None:
|
||||
config = get_vae_config(args.version)
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
|
||||
@@ -88,18 +88,13 @@ def main(args):
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 3.0
|
||||
|
||||
# model config
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
layer_num = 20
|
||||
elif args.model_type == "SanaMS_600M_P1_D28":
|
||||
layer_num = 28
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
# Positional embedding interpolation scale.
|
||||
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
@@ -181,7 +176,6 @@ def main(args):
|
||||
patch_size=1,
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
interpolation_scale=interpolation_scale[args.image_size],
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
|
||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.33.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.33.0.dev0"
|
||||
__version__ = "0.32.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -2286,50 +2286,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
|
||||
transformer._transformer_norm_layers = None
|
||||
|
||||
if getattr(transformer, "_overwritten_params", None) is not None:
|
||||
overwritten_params = transformer._overwritten_params
|
||||
module_names = set()
|
||||
|
||||
for param_name in overwritten_params:
|
||||
if param_name.endswith(".weight"):
|
||||
module_names.add(param_name.replace(".weight", ""))
|
||||
|
||||
for name, module in transformer.named_modules():
|
||||
if isinstance(module, torch.nn.Linear) and name in module_names:
|
||||
module_weight = module.weight.data
|
||||
module_bias = module.bias.data if module.bias is not None else None
|
||||
bias = module_bias is not None
|
||||
|
||||
parent_module_name, _, current_module_name = name.rpartition(".")
|
||||
parent_module = transformer.get_submodule(parent_module_name)
|
||||
|
||||
current_param_weight = overwritten_params[f"{name}.weight"]
|
||||
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
|
||||
with torch.device("meta"):
|
||||
original_module = torch.nn.Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=module_weight.dtype,
|
||||
)
|
||||
|
||||
tmp_state_dict = {"weight": current_param_weight}
|
||||
if module_bias is not None:
|
||||
tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
|
||||
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
|
||||
setattr(parent_module, current_module_name, original_module)
|
||||
|
||||
del tmp_state_dict
|
||||
|
||||
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
|
||||
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
|
||||
new_value = int(current_param_weight.shape[1])
|
||||
old_value = getattr(transformer.config, attribute_name)
|
||||
setattr(transformer.config, attribute_name, new_value)
|
||||
logger.info(
|
||||
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _maybe_expand_transformer_param_shape_or_error_(
|
||||
cls,
|
||||
@@ -2356,8 +2312,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# Expand transformer parameter shapes if they don't match lora
|
||||
has_param_with_shape_update = False
|
||||
overwritten_params = {}
|
||||
|
||||
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
|
||||
for name, module in transformer.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
@@ -2432,16 +2386,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
|
||||
)
|
||||
|
||||
# For `unload_lora_weights()`.
|
||||
# TODO: this could lead to more memory overhead if the number of overwritten params
|
||||
# are large. Should be revisited later and tackled through a `discard_original_layers` arg.
|
||||
overwritten_params[f"{current_module_name}.weight"] = module_weight
|
||||
if module_bias is not None:
|
||||
overwritten_params[f"{current_module_name}.bias"] = module_bias
|
||||
|
||||
if len(overwritten_params) > 0:
|
||||
transformer._overwritten_params = overwritten_params
|
||||
|
||||
return has_param_with_shape_update
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -28,7 +28,6 @@ from .single_file_utils import (
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
@@ -102,10 +101,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"HunyuanVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -225,7 +220,6 @@ class FromOriginalModelMixin:
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
@@ -303,7 +297,7 @@ class FromOriginalModelMixin:
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=config_revision,
|
||||
revision=revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
|
||||
@@ -108,7 +108,6 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
|
||||
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
|
||||
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -157,14 +156,12 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
|
||||
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
|
||||
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -606,10 +603,7 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
model_type = "ltx-video-0.9.1"
|
||||
else:
|
||||
model_type = "ltx-video"
|
||||
model_type = "ltx-video"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
|
||||
encoder_key = "encoder.project_in.conv.conv.bias"
|
||||
@@ -630,9 +624,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
|
||||
model_type = "mochi-1-preview"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
|
||||
model_type = "hunyuan-video"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -2342,32 +2333,12 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
@@ -2551,133 +2522,3 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict):
|
||||
hidden_size = 3072
|
||||
|
||||
if "linear1.weight" in key:
|
||||
linear1_weight = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
||||
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
|
||||
state_dict[f"{new_key}.attn.to_q.weight"] = q
|
||||
state_dict[f"{new_key}.attn.to_k.weight"] = k
|
||||
state_dict[f"{new_key}.attn.to_v.weight"] = v
|
||||
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
|
||||
|
||||
elif "linear1.bias" in key:
|
||||
linear1_bias = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
|
||||
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
|
||||
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
|
||||
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
|
||||
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
|
||||
|
||||
else:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
||||
new_key = new_key.replace("linear2", "proj_out")
|
||||
new_key = new_key.replace("q_norm", "attn.norm_q")
|
||||
new_key = new_key.replace("k_norm", "attn.norm_k")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
def update_state_dict_(state_dict, old_key, new_key):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(checkpoint, key, new_key)
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, checkpoint)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@@ -177,5 +177,3 @@ class FluxTransformer2DLoadersMixin:
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
@@ -4839,8 +4839,6 @@ class IPAdapterAttnProcessor(nn.Module):
|
||||
)
|
||||
else:
|
||||
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
|
||||
if mask is None:
|
||||
continue
|
||||
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
|
||||
raise ValueError(
|
||||
"Each element of the ip_adapter_masks array should be a tensor with shape "
|
||||
@@ -5058,8 +5056,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
|
||||
if mask is None:
|
||||
continue
|
||||
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
|
||||
raise ValueError(
|
||||
"Each element of the ip_adapter_masks array should be a tensor with shape "
|
||||
|
||||
@@ -22,14 +22,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class LTXVideoCausalConv3d(nn.Module):
|
||||
class LTXCausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -80,9 +79,9 @@ class LTXVideoCausalConv3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoResnetBlock3d(nn.Module):
|
||||
class LTXResnetBlock3d(nn.Module):
|
||||
r"""
|
||||
A 3D ResNet block used in the LTXVideo model.
|
||||
A 3D ResNet block used in the LTX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -110,9 +109,7 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
elementwise_affine: bool = False,
|
||||
non_linearity: str = "swish",
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
@@ -120,13 +117,13 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
|
||||
self.conv1 = LTXVideoCausalConv3d(
|
||||
self.conv1 = LTXCausalConv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = LTXVideoCausalConv3d(
|
||||
self.conv2 = LTXCausalConv3d(
|
||||
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
|
||||
)
|
||||
|
||||
@@ -134,58 +131,22 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
|
||||
self.conv_shortcut = LTXVideoCausalConv3d(
|
||||
self.conv_shortcut = LTXCausalConv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.per_channel_scale1 = None
|
||||
self.per_channel_scale2 = None
|
||||
if inject_noise:
|
||||
self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
|
||||
self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
|
||||
|
||||
self.scale_shift_table = None
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
|
||||
|
||||
def forward(
|
||||
self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
|
||||
) -> torch.Tensor:
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = inputs
|
||||
|
||||
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.scale_shift_table is not None:
|
||||
temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
|
||||
shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
|
||||
hidden_states = hidden_states * (1 + scale_1) + shift_1
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if self.per_channel_scale1 is not None:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
spatial_noise = torch.randn(
|
||||
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)[None]
|
||||
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
|
||||
|
||||
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.scale_shift_table is not None:
|
||||
hidden_states = hidden_states * (1 + scale_2) + shift_2
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.per_channel_scale2 is not None:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
spatial_noise = torch.randn(
|
||||
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)[None]
|
||||
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
|
||||
|
||||
if self.norm3 is not None:
|
||||
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
@@ -196,24 +157,20 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
class LTXUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
||||
self.residual = residual
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
out_channels = in_channels * stride[0] * stride[1] * stride[2]
|
||||
|
||||
self.conv = LTXVideoCausalConv3d(
|
||||
self.conv = LTXCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
@@ -224,15 +181,6 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.residual:
|
||||
residual = hidden_states.reshape(
|
||||
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
|
||||
)
|
||||
residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
|
||||
residual = residual.repeat(1, repeats, 1, 1, 1)
|
||||
residual = residual[:, :, self.stride[0] - 1 :]
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
|
||||
@@ -240,15 +188,12 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
|
||||
|
||||
if self.residual:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoDownBlock3D(nn.Module):
|
||||
class LTXDownBlock3D(nn.Module):
|
||||
r"""
|
||||
Down block used in the LTXVideo model.
|
||||
Down block used in the LTX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -290,7 +235,7 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXVideoResnetBlock3d(
|
||||
LTXResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
@@ -305,7 +250,7 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
if spatio_temporal_scale:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoCausalConv3d(
|
||||
LTXCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
@@ -317,7 +262,7 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
self.conv_out = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_out = LTXVideoResnetBlock3d(
|
||||
self.conv_out = LTXResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
@@ -328,12 +273,7 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXDownBlock3D` class."""
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
@@ -345,26 +285,24 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
if self.conv_out is not None:
|
||||
hidden_states = self.conv_out(hidden_states, temb, generator)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
|
||||
class LTXVideoMidBlock3d(nn.Module):
|
||||
class LTXMidBlock3d(nn.Module):
|
||||
r"""
|
||||
A middle block used in the LTXVideo model.
|
||||
A middle block used in the LTX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -391,51 +329,28 @@ class LTXVideoMidBlock3d(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.time_embedder = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXVideoResnetBlock3d(
|
||||
LTXResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXMidBlock3D` class."""
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -445,18 +360,16 @@ class LTXVideoMidBlock3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoUpBlock3d(nn.Module):
|
||||
class LTXUpBlock3d(nn.Module):
|
||||
r"""
|
||||
Up block used in the LTXVideo model.
|
||||
Up block used in the LTX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -490,82 +403,45 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.time_embedder = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
|
||||
|
||||
self.conv_in = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_in = LTXVideoResnetBlock3d(
|
||||
self.conv_in = LTXResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
|
||||
self.upsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoUpsampler3d(
|
||||
out_channels * upscale_factor,
|
||||
stride=(2, 2, 2),
|
||||
is_causal=is_causal,
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
)
|
||||
]
|
||||
)
|
||||
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXVideoResnetBlock3d(
|
||||
LTXResnetBlock3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.conv_in is not None:
|
||||
hidden_states = self.conv_in(hidden_states, temb, generator)
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -580,18 +456,16 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoEncoder3d(nn.Module):
|
||||
class LTXEncoder3d(nn.Module):
|
||||
r"""
|
||||
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
|
||||
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
|
||||
representation.
|
||||
|
||||
Args:
|
||||
@@ -635,7 +509,7 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
self.conv_in = LTXCausalConv3d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
@@ -650,7 +524,7 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
down_block = LTXDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
@@ -662,7 +536,7 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid block
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
self.mid_block = LTXMidBlock3d(
|
||||
in_channels=output_channel,
|
||||
num_layers=layers_per_block[-1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
@@ -672,14 +546,14 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
# out
|
||||
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LTXVideoCausalConv3d(
|
||||
self.conv_out = LTXCausalConv3d(
|
||||
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""The forward method of the `LTXVideoEncoder3d` class."""
|
||||
r"""The forward method of the `LTXEncoder3D` class."""
|
||||
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
@@ -725,10 +599,9 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoDecoder3d(nn.Module):
|
||||
class LTXDecoder3d(nn.Module):
|
||||
r"""
|
||||
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
|
||||
sample.
|
||||
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 128):
|
||||
@@ -749,8 +622,6 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
Epsilon value for ResNet normalization layers.
|
||||
is_causal (`bool`, defaults to `False`):
|
||||
Whether this layer behaves causally (future frames depend only on past frames) or not.
|
||||
timestep_conditioning (`bool`, defaults to `False`):
|
||||
Whether to condition the model on timesteps.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -764,10 +635,6 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: Tuple[bool, ...] = (False, False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -778,42 +645,30 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
block_out_channels = tuple(reversed(block_out_channels))
|
||||
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
|
||||
layers_per_block = tuple(reversed(layers_per_block))
|
||||
inject_noise = tuple(reversed(inject_noise))
|
||||
upsample_residual = tuple(reversed(upsample_residual))
|
||||
upsample_factor = tuple(reversed(upsample_factor))
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
self.conv_in = LTXCausalConv3d(
|
||||
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
in_channels=output_channel,
|
||||
num_layers=layers_per_block[0],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise[0],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
self.mid_block = LTXMidBlock3d(
|
||||
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
|
||||
)
|
||||
|
||||
# up blocks
|
||||
num_block_out_channels = len(block_out_channels)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i in range(num_block_out_channels):
|
||||
input_channel = output_channel // upsample_factor[i]
|
||||
output_channel = block_out_channels[i] // upsample_factor[i]
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
up_block = LTXVideoUpBlock3d(
|
||||
up_block = LTXUpBlock3d(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i + 1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise[i + 1],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
upsample_residual=upsample_residual[i],
|
||||
upscale_factor=upsample_factor[i],
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
@@ -821,20 +676,13 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
# out
|
||||
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LTXVideoCausalConv3d(
|
||||
self.conv_out = LTXCausalConv3d(
|
||||
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
# timestep embedding
|
||||
self.time_embedder = None
|
||||
self.scale_shift_table = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
@@ -845,33 +693,17 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states, temb)
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states, temb)
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
|
||||
temb = temb + self.scale_shift_table[None, ..., None, None, None]
|
||||
shift, scale = temb.unbind(dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
@@ -934,15 +766,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -952,7 +777,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = LTXVideoEncoder3d(
|
||||
self.encoder = LTXEncoder3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
@@ -963,20 +788,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
is_causal=encoder_causal,
|
||||
)
|
||||
self.decoder = LTXVideoDecoder3d(
|
||||
self.decoder = LTXDecoder3d(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
block_out_channels=block_out_channels,
|
||||
spatio_temporal_scaling=spatio_temporal_scaling,
|
||||
layers_per_block=layers_per_block,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
is_causal=decoder_causal,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
inject_noise=decoder_inject_noise,
|
||||
upsample_residual=upsample_residual,
|
||||
upsample_factor=upsample_factor,
|
||||
)
|
||||
|
||||
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
|
||||
@@ -1016,7 +837,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_sample_stride_width = 448
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
|
||||
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(
|
||||
@@ -1115,15 +936,13 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z, temb, return_dict=return_dict)
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
if self.use_framewise_decoding:
|
||||
# TODO(aryan): requires investigation
|
||||
@@ -1133,7 +952,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
dec = self.decoder(z, temb)
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
@@ -1141,9 +960,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1158,15 +975,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
if temb is not None:
|
||||
decoded_slices = [
|
||||
self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1))
|
||||
]
|
||||
else:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z, temb).sample
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
@@ -1248,9 +1060,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1291,9 +1101,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
time = self.decoder(
|
||||
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
|
||||
)
|
||||
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
|
||||
|
||||
row.append(time)
|
||||
rows.append(row)
|
||||
@@ -1321,7 +1129,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
@@ -1332,7 +1139,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, temb)
|
||||
dec = self.decode(z)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return dec
|
||||
|
||||
@@ -748,10 +748,10 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
pos_embedding = self._get_positional_embeddings(
|
||||
height, width, pre_time_compression_frames, device=embeds.device
|
||||
)
|
||||
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
|
||||
else:
|
||||
pos_embedding = self.pos_embedding
|
||||
|
||||
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
|
||||
embeds = embeds + pos_embedding
|
||||
|
||||
return embeds
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
|
||||
class ModelHook:
|
||||
r"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model.
|
||||
"""
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when a model is initialized.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module attached to this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
||||
r"""
|
||||
Hook that is executed just before the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass will be executed just after this event.
|
||||
args (`Tuple[Any]`):
|
||||
The positional arguments passed to the module.
|
||||
kwargs (`Dict[Str, Any]`):
|
||||
The keyword arguments passed to the module.
|
||||
Returns:
|
||||
`Tuple[Tuple[Any], Dict[Str, Any]]`:
|
||||
A tuple with the treated `args` and `kwargs`.
|
||||
"""
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
|
||||
r"""
|
||||
Hook that is executed just after the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass been executed just before this event.
|
||||
output (`Any`):
|
||||
The output of the module.
|
||||
Returns:
|
||||
`Any`: The processed `output`.
|
||||
"""
|
||||
return output
|
||||
|
||||
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when the hook is detached from a module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module detached from this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self._is_stateful:
|
||||
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
|
||||
return module
|
||||
|
||||
|
||||
class SequentialHook(ModelHook):
|
||||
r"""A hook that can contain several hooks and iterates through them at each event."""
|
||||
|
||||
def __init__(self, *hooks):
|
||||
self.hooks = hooks
|
||||
|
||||
def init_hook(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.init_hook(module)
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
for hook in self.hooks:
|
||||
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
for hook in self.hooks:
|
||||
output = hook.post_forward(module, output)
|
||||
return output
|
||||
|
||||
def detach_hook(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.detach_hook(module)
|
||||
return module
|
||||
|
||||
def reset_state(self, module):
|
||||
for hook in self.hooks:
|
||||
if hook._is_stateful:
|
||||
hook.reset_state(module)
|
||||
return module
|
||||
|
||||
|
||||
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module:
|
||||
r"""
|
||||
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
|
||||
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
|
||||
together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to attach a hook to.
|
||||
hook (`ModelHook`):
|
||||
The hook to attach.
|
||||
append (`bool`, *optional*, defaults to `False`):
|
||||
Whether the hook should be chained with an existing one (if module already contains a hook) or not.
|
||||
Returns:
|
||||
`torch.nn.Module`:
|
||||
The same module, with the hook attached (the module is modified in place, so the result can be discarded).
|
||||
"""
|
||||
original_hook = hook
|
||||
|
||||
if append and getattr(module, "_diffusers_hook", None) is not None:
|
||||
old_hook = module._diffusers_hook
|
||||
remove_hook_from_module(module)
|
||||
hook = SequentialHook(old_hook, hook)
|
||||
|
||||
if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
|
||||
# If we already put some hook on this module, we replace it with the new one.
|
||||
old_forward = module._old_forward
|
||||
else:
|
||||
old_forward = module.forward
|
||||
module._old_forward = old_forward
|
||||
|
||||
module = hook.init_hook(module)
|
||||
module._diffusers_hook = hook
|
||||
|
||||
if hasattr(original_hook, "new_forward"):
|
||||
new_forward = original_hook.new_forward
|
||||
else:
|
||||
|
||||
def new_forward(module, *args, **kwargs):
|
||||
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
|
||||
output = module._old_forward(*args, **kwargs)
|
||||
return module._diffusers_hook.post_forward(module, output)
|
||||
|
||||
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
|
||||
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
|
||||
if "GraphModuleImpl" in str(type(module)):
|
||||
module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
|
||||
else:
|
||||
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
|
||||
"""
|
||||
Removes any hook attached to a module via `add_hook_to_module`.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to attach a hook to.
|
||||
recurse (`bool`, defaults to `False`):
|
||||
Whether to remove the hooks recursively
|
||||
Returns:
|
||||
`torch.nn.Module`:
|
||||
The same module, with the hook detached (the module is modified in place, so the result can be discarded).
|
||||
"""
|
||||
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.detach_hook(module)
|
||||
delattr(module, "_diffusers_hook")
|
||||
|
||||
if hasattr(module, "_old_forward"):
|
||||
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
|
||||
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
|
||||
if "GraphModuleImpl" in str(type(module)):
|
||||
module.__class__.forward = module._old_forward
|
||||
else:
|
||||
module.forward = module._old_forward
|
||||
delattr(module, "_old_forward")
|
||||
|
||||
if recurse:
|
||||
for child in module.children():
|
||||
remove_hook_from_module(child, recurse)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
|
||||
"""
|
||||
Resets the state of all stateful hooks attached to a module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to reset the stateful hooks from.
|
||||
"""
|
||||
if hasattr(module, "_diffusers_hook") and (
|
||||
module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)
|
||||
):
|
||||
module._diffusers_hook.reset_state(module)
|
||||
|
||||
if recurse:
|
||||
for child in module.children():
|
||||
reset_stateful_hooks(child, recurse)
|
||||
@@ -228,7 +228,7 @@ def load_model_dict_into_meta(
|
||||
else:
|
||||
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
||||
raise ValueError(
|
||||
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
)
|
||||
|
||||
if is_quantized and (
|
||||
|
||||
@@ -718,9 +718,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
hf_quantizer = None
|
||||
|
||||
if hf_quantizer is not None:
|
||||
if device_map is not None:
|
||||
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
|
||||
if is_bnb_quantization_method and device_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
|
||||
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
|
||||
)
|
||||
|
||||
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
|
||||
@@ -819,8 +820,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
revision=revision,
|
||||
subfolder=subfolder or "",
|
||||
)
|
||||
# TODO: https://github.com/huggingface/diffusers/issues/10013
|
||||
if hf_quantizer is not None:
|
||||
if hf_quantizer is not None and is_bnb_quantization_method:
|
||||
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
|
||||
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
|
||||
is_sharded = False
|
||||
|
||||
@@ -242,7 +242,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
patch_size: int = 1,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -250,14 +249,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch Embedding
|
||||
interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
|
||||
self.patch_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
interpolation_scale=None,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
|
||||
@@ -18,8 +18,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
@@ -502,7 +500,7 @@ class HunyuanVideoTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LTXVideoAttentionProcessor2_0:
|
||||
class LTXAttentionProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
|
||||
@@ -44,7 +44,7 @@ class LTXVideoAttentionProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
@@ -92,7 +92,7 @@ class LTXVideoAttentionProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
class LTXRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -164,7 +164,7 @@ class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class LTXVideoTransformerBlock(nn.Module):
|
||||
class LTXTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
|
||||
|
||||
@@ -208,7 +208,7 @@ class LTXVideoTransformerBlock(nn.Module):
|
||||
cross_attention_dim=None,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXVideoAttentionProcessor2_0(),
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -221,7 +221,7 @@ class LTXVideoTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXVideoAttentionProcessor2_0(),
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
@@ -327,7 +327,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.rope = LTXVideoRotaryPosEmbed(
|
||||
self.rope = LTXRotaryPosEmbed(
|
||||
dim=inner_dim,
|
||||
base_num_frames=20,
|
||||
base_height=2048,
|
||||
@@ -339,7 +339,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LTXVideoTransformerBlock(
|
||||
LTXTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
|
||||
@@ -21,18 +21,11 @@ from ...image_processor import VaeImageProcessor
|
||||
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -571,9 +564,6 @@ class AuraFlowPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
|
||||
@@ -39,7 +39,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
>>> model_id = "tencent/HunyuanVideo"
|
||||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
@@ -193,15 +193,15 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
|
||||
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
|
||||
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
|
||||
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
@@ -411,7 +411,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
@@ -419,8 +419,8 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
@@ -652,7 +652,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
@@ -660,8 +660,8 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
|
||||
@@ -511,8 +511,6 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -565,10 +563,6 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -759,25 +753,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
|
||||
@@ -571,8 +571,6 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -627,10 +625,6 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -855,25 +849,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
|
||||
@@ -59,13 +59,13 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import SanaPAGPipeline
|
||||
|
||||
>>> pipe = SanaPAGPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
|
||||
... pag_applied_layers=["transformer_blocks.8"],
|
||||
... torch_dtype=torch.float32,
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> pipe.text_encoder.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.float16)
|
||||
|
||||
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
|
||||
>>> image[0].save("output.png")
|
||||
|
||||
@@ -31,7 +31,6 @@ from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
@@ -47,13 +46,6 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
|
||||
from .pipeline_output import SanaPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_bs4_available():
|
||||
@@ -70,11 +62,11 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import SanaPipeline
|
||||
|
||||
>>> pipe = SanaPipeline.from_pretrained(
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
|
||||
... "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float32
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> pipe.text_encoder.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
|
||||
>>> pipe.transformer = pipe.transformer.to(torch.float16)
|
||||
|
||||
>>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
|
||||
>>> image[0].save("output.png")
|
||||
@@ -872,9 +864,6 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
|
||||
@@ -226,21 +226,12 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
|
||||
)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
|
||||
)
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length
|
||||
self.default_sample_size = self.transformer.config.sample_size
|
||||
self.patch_size = (
|
||||
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
|
||||
)
|
||||
|
||||
@@ -225,28 +225,19 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
|
||||
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
|
||||
)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor,
|
||||
vae_latent_channels=latent_channels,
|
||||
vae_latent_channels=self.vae.config.latent_channels,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length
|
||||
self.default_sample_size = self.transformer.config.sample_size
|
||||
self.patch_size = (
|
||||
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
|
||||
)
|
||||
|
||||
@@ -1,252 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..models import (
|
||||
FluxTransformer2DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
)
|
||||
from ..models.hooks import ModelHook, add_hook_to_module
|
||||
from ..utils import logging
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Source: https://github.com/ali-vilab/TeaCache
|
||||
# TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly, and export to file for re-use.
|
||||
# fmt: off
|
||||
_MODEL_TO_POLY_COEFFICIENTS = {
|
||||
FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
|
||||
HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02],
|
||||
LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03],
|
||||
LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
|
||||
MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03],
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
_MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = {
|
||||
FluxTransformer2DModel: 0.25,
|
||||
HunyuanVideoTransformer3DModel: 0.1,
|
||||
LTXVideoTransformer3DModel: 0.05,
|
||||
LuminaNextDiT2DModel: 0.2,
|
||||
MochiTransformer3DModel: 0.06,
|
||||
}
|
||||
|
||||
_MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = {
|
||||
FluxTransformer2DModel: "transformer_blocks.0.norm1",
|
||||
}
|
||||
|
||||
_MODEL_TO_SKIP_END_LAYER_IDENTIFIER = {
|
||||
FluxTransformer2DModel: "norm_out",
|
||||
}
|
||||
|
||||
_DEFAULT_SKIP_LAYER_IDENTIFIERS = [
|
||||
"blocks",
|
||||
"transformer_blocks",
|
||||
"single_transformer_blocks",
|
||||
"temporal_transformer_blocks",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeaCacheConfig:
|
||||
l1_threshold: Optional[float] = None
|
||||
|
||||
skip_layer_identifiers: List[str] = _DEFAULT_SKIP_LAYER_IDENTIFIERS
|
||||
|
||||
_polynomial_coefficients: Optional[List[float]] = None
|
||||
|
||||
|
||||
class TeaCacheDenoiserState:
|
||||
def __init__(self):
|
||||
self.iteration: int = 0
|
||||
self.accumulated_l1_difference: float = 0.0
|
||||
self.timestep_modulated_cache: torch.Tensor = None
|
||||
self.residual_cache: torch.Tensor = None
|
||||
self.should_skip_blocks: bool = False
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
self.accumulated_l1_difference = 0.0
|
||||
self.timestep_modulated_cache = None
|
||||
self.residual_cache = None
|
||||
|
||||
|
||||
def apply_teacache(
|
||||
pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None
|
||||
) -> None:
|
||||
r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module.
|
||||
|
||||
Args:
|
||||
TODO
|
||||
"""
|
||||
|
||||
if config is None:
|
||||
logger.warning("No TeaCacheConfig provided. Using default configuration.")
|
||||
config = TeaCacheConfig()
|
||||
|
||||
if denoiser is None:
|
||||
denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
|
||||
|
||||
if isinstance(denoiser, (_MODEL_TO_POLY_COEFFICIENTS.keys())):
|
||||
if config.l1_threshold is None:
|
||||
logger.info(
|
||||
f"No L1 threshold was provided for {type(denoiser)}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. "
|
||||
f"For higher speedup, increase the threshold."
|
||||
)
|
||||
config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)]
|
||||
if config.timestep_modulated_layer_identifier is None:
|
||||
logger.info(
|
||||
f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using default identifier as provided in the TeaCache paper."
|
||||
)
|
||||
config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)]
|
||||
if config._polynomial_coefficients is None:
|
||||
logger.info(
|
||||
f"No polynomial coefficients were provided for {type(denoiser)}. Using default coefficients as provided in the TeaCache paper."
|
||||
)
|
||||
config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)]
|
||||
else:
|
||||
if config.l1_threshold is None:
|
||||
raise ValueError(
|
||||
f"No L1 threshold was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
|
||||
f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
|
||||
)
|
||||
if config.timestep_modulated_layer_identifier is None:
|
||||
raise ValueError(
|
||||
f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
|
||||
f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
|
||||
)
|
||||
if config._polynomial_coefficients is None:
|
||||
raise ValueError(
|
||||
f"No polynomial coefficients were provided for {type(denoiser)}. Using TeaCache with this model is not "
|
||||
f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
|
||||
f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
|
||||
)
|
||||
|
||||
timestep_modulated_layer_matches = list(
|
||||
{
|
||||
module
|
||||
for name, module in denoiser.named_modules()
|
||||
if re.match(config.timestep_modulated_layer_identifier, name)
|
||||
}
|
||||
)
|
||||
|
||||
if len(timestep_modulated_layer_matches) == 0:
|
||||
raise ValueError(
|
||||
f"No layer in the denoiser module matched the provided timestep modulated layer identifier: "
|
||||
f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier."
|
||||
)
|
||||
if len(timestep_modulated_layer_matches) > 1:
|
||||
logger.warning(
|
||||
f"Multiple layers in the denoiser module matched the provided timestep modulated layer identifier: "
|
||||
f"{config.timestep_modulated_layer_identifier}. Using the first match."
|
||||
)
|
||||
|
||||
denoiser_state = TeaCacheDenoiserState()
|
||||
|
||||
timestep_modulated_layer = timestep_modulated_layer_matches[0]
|
||||
hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients)
|
||||
add_hook_to_module(timestep_modulated_layer, hook, append=True)
|
||||
|
||||
skip_layer_identifiers = config.skip_layer_identifiers
|
||||
skip_layer_matches = list(
|
||||
{
|
||||
module
|
||||
for name, module in denoiser.named_modules()
|
||||
if any(re.match(identifier, name) for identifier in skip_layer_identifiers)
|
||||
}
|
||||
)
|
||||
|
||||
for skip_layer in skip_layer_matches:
|
||||
hook = DenoiserStateBasedSkipLayerHook(denoiser_state)
|
||||
add_hook_to_module(skip_layer, hook, append=True)
|
||||
|
||||
|
||||
class TimestepModulatedOutputCacheHook(ModelHook):
|
||||
# The denoiser hook will reset its state, so we don't have to handle it here
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
denoiser_state: TeaCacheDenoiserState,
|
||||
l1_threshold: float,
|
||||
polynomial_coefficients: List[float],
|
||||
) -> None:
|
||||
self.denoiser_state = denoiser_state
|
||||
self.l1_threshold = l1_threshold
|
||||
# TODO(aryan): implement torch equivalent
|
||||
self.rescale_fn = np.poly1d(polynomial_coefficients)
|
||||
|
||||
def post_forward(self, module, output):
|
||||
if isinstance(output, tuple):
|
||||
# This assumes that the first element of the output tuple is the timestep modulated noise output.
|
||||
# For Diffusers models, this is true. For models outside diffusers, users will have to ensure
|
||||
# that the first element of the output tuple is the timestep modulated noise output (seems to be
|
||||
# the case for most research model implementations).
|
||||
timestep_modulated_noise = output[0]
|
||||
elif torch.is_tensor(output):
|
||||
timestep_modulated_noise = output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected output to be a tensor or a tuple with first element as timestep modulated noise. "
|
||||
f"Got {type(output)} instead. Please ensure that the denoiser module returns the timestep "
|
||||
f"modulated noise output as the first element."
|
||||
)
|
||||
|
||||
if self.denoiser_state.timestep_modulated_cache is not None:
|
||||
l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean()
|
||||
normalized_l1_diff = l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean()
|
||||
rescaled_l1_diff = self.rescale_fn(normalized_l1_diff)
|
||||
self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff
|
||||
|
||||
if self.denoiser_state.accumulated_l1_difference >= self.l1_threshold:
|
||||
self.denoiser_state.should_skip_blocks = True
|
||||
self.denoiser_state.accumulated_l1_difference = 0.0
|
||||
else:
|
||||
self.denoiser_state.should_skip_blocks = False
|
||||
|
||||
self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise
|
||||
return output
|
||||
|
||||
|
||||
class DenoiserStateBasedSkipLayerHook(ModelHook):
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None:
|
||||
self.denoiser_state = denoiser_state
|
||||
|
||||
def new_forward(self, module, *args, **kwargs):
|
||||
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
|
||||
|
||||
if not self.denoiser_state.should_skip_blocks:
|
||||
output = module._old_forward(*args, **kwargs)
|
||||
else:
|
||||
# Diffusers models either expect one output (hidden_states) or a tuple of two outputs (hidden_states, encoder_hidden_states).
|
||||
# Returning a tuple of None values handles both cases. It is okay to do because we are not going to be using these
|
||||
# anywhere if self.denoiser_state.should_skip_blocks is True.
|
||||
output = (None, None)
|
||||
|
||||
return module._diffusers_hook.post_forward(module, output)
|
||||
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
|
||||
from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
@@ -35,28 +35,21 @@ if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
if is_torch_version(">=", "2.5"):
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
# At the moment, only int8 is supported for integer quantization dtypes.
|
||||
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
|
||||
# to support more quantization methods, such as intx_weight_only.
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint1,
|
||||
torch.uint2,
|
||||
torch.uint3,
|
||||
torch.uint4,
|
||||
torch.uint5,
|
||||
torch.uint6,
|
||||
torch.uint7,
|
||||
)
|
||||
else:
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
)
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
# At the moment, only int8 is supported for integer quantization dtypes.
|
||||
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
|
||||
# to support more quantization methods, such as intx_weight_only.
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint1,
|
||||
torch.uint2,
|
||||
torch.uint3,
|
||||
torch.uint4,
|
||||
torch.uint5,
|
||||
torch.uint6,
|
||||
torch.uint7,
|
||||
)
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import quantize_
|
||||
@@ -100,11 +93,6 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
raise ImportError(
|
||||
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
|
||||
)
|
||||
torchao_version = version.parse(importlib.metadata.version("torch"))
|
||||
if torchao_version < version.parse("0.7.0"):
|
||||
raise RuntimeError(
|
||||
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
)
|
||||
|
||||
self.offload = False
|
||||
|
||||
@@ -132,7 +120,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
def update_torch_dtype(self, torch_dtype):
|
||||
quant_type = self.quantization_config.quant_type
|
||||
|
||||
if quant_type.startswith("int") or quant_type.startswith("uint"):
|
||||
if quant_type.startswith("int"):
|
||||
if torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning(
|
||||
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user