Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e68c936f42 | |||
| dccc206e35 | |||
| 6f2ded53a1 | |||
| 6d2a80c14b | |||
| 219a8ab031 | |||
| 3a00e23f5a | |||
| 19fe63170c | |||
| 41381b1bb1 | |||
| bcada5bfaf | |||
| 4490e4cc44 | |||
| 27c1ac49b4 | |||
| 585c32b304 | |||
| ca5afaebca | |||
| 6c066f0e13 | |||
| fbb25a05be | |||
| fbc4c998ed | |||
| 56d2986d5d | |||
| a33ef355f6 | |||
| 85b7478fe9 | |||
| d1e6ffffad | |||
| 61c6eae207 | |||
| a076cd8e16 | |||
| 2b72beefe7 | |||
| 11bf2cf1d1 | |||
| 19921e9362 | |||
| 5aa4f1dc55 | |||
| 922e273e6b |
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -e .
|
||||
|
||||
@@ -35,7 +35,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
|
||||
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -e .
|
||||
|
||||
@@ -47,7 +47,7 @@ jobs:
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
|
||||
@@ -373,8 +373,6 @@
|
||||
title: QwenImageTransformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/sana_video_transformer3d
|
||||
title: SanaVideoTransformer3DModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/skyreels_v2_transformer_3d
|
||||
@@ -531,6 +529,8 @@
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kandinsky5
|
||||
title: Kandinsky 5
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
@@ -565,8 +565,6 @@
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/sana_video
|
||||
title: Sana Video
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
@@ -640,8 +638,6 @@
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/kandinsky5_video
|
||||
title: Kandinsky 5.0 Video
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx_video
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# SanaVideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import SanaVideoTransformer3DModel
|
||||
import torch
|
||||
|
||||
transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## SanaVideoTransformer3DModel
|
||||
|
||||
[[autodoc]] SanaVideoTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
|
||||
+4
-4
@@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Kandinsky 5.0 Video
|
||||
# Kandinsky 5.0
|
||||
|
||||
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
|
||||
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
|
||||
|
||||
|
||||
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
|
||||
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda")
|
||||
|
||||
pipe.transformer.set_attention_backend(
|
||||
"flex"
|
||||
) # <--- Sett attention bakend to Flex
|
||||
) # <--- Set attention backend to Flex
|
||||
pipe.transformer.compile(
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
dynamic=True
|
||||
@@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
### Diffusion Distilled model
|
||||
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
|
||||
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
|
||||
|
||||
```python
|
||||
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
|
||||
@@ -24,6 +24,9 @@ The abstract from the paper is:
|
||||
|
||||
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
|
||||
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
|
||||
|
||||
Available models:
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# SanaVideoPipeline
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*
|
||||
|
||||
This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` |
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModel.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = SanaVideoPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
model_score = 30
|
||||
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
|
||||
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
||||
motion_prompt = f" motion score: {model_score}."
|
||||
prompt = prompt + motion_prompt
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=832,
|
||||
num_frames=81,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=50
|
||||
).frames[0]
|
||||
export_to_video(output, "sana-video-output.mp4", fps=16)
|
||||
```
|
||||
|
||||
## SanaVideoPipeline
|
||||
|
||||
[[autodoc]] SanaVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaVideoPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# LoopSequentialPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
|
||||
|
||||
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
|
||||
|
||||
@@ -21,6 +21,7 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
|
||||
|
||||
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
|
||||
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
|
||||
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
|
||||
- `__call__` method defines the loop structure and iteration logic.
|
||||
|
||||
@@ -89,4 +90,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
|
||||
|
||||
```py
|
||||
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
|
||||
```
|
||||
```
|
||||
@@ -37,7 +37,17 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
|
||||
]
|
||||
```
|
||||
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
|
||||
|
||||
Use `InputParam` to define `intermediate_inputs`.
|
||||
|
||||
```py
|
||||
user_intermediate_inputs = [
|
||||
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
|
||||
]
|
||||
```
|
||||
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
|
||||
Use `OutputParam` to define `intermediate_outputs`.
|
||||
|
||||
@@ -55,8 +65,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
|
||||
|
||||
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
|
||||
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
|
||||
2. Implement the computation logic on the `inputs`.
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
|
||||
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
|
||||
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
|
||||
4. Return the components and state which becomes available to the next block.
|
||||
|
||||
@@ -66,7 +76,7 @@ def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Your computation logic here
|
||||
# block_state contains all your inputs
|
||||
# block_state contains all your inputs and intermediate_inputs
|
||||
# Access them like: block_state.image, block_state.processed_image
|
||||
|
||||
# Update the pipeline state with your updated block_states
|
||||
@@ -102,4 +112,4 @@ def __call__(self, components, state):
|
||||
unet = components.unet
|
||||
vae = components.vae
|
||||
scheduler = components.scheduler
|
||||
```
|
||||
```
|
||||
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
|
||||
components = ComponentManager()
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
|
||||
dd_pipeline.load_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
|
||||
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# SequentialPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
|
||||
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
|
||||
|
||||
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
|
||||
|
||||
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
|
||||
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
|
||||
|
||||
<hfoptions id="sequential">
|
||||
<hfoption id="InputBlock">
|
||||
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
|
||||
```py
|
||||
print(blocks)
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
@@ -104,8 +104,6 @@ To use your own dataset, there are 2 ways:
|
||||
- you can either provide your own folder as `--train_data_dir`
|
||||
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
|
||||
|
||||
If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
|
||||
|
||||
Below, we explain both in more detail.
|
||||
|
||||
#### Provide the dataset as a folder
|
||||
|
||||
@@ -52,24 +52,6 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
return res.expand(broadcast_shape)
|
||||
|
||||
|
||||
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
|
||||
"""
|
||||
if tensor.ndim == 2:
|
||||
tensor = tensor.unsqueeze(0)
|
||||
channels = tensor.shape[0]
|
||||
if channels == 3:
|
||||
return tensor
|
||||
if channels == 1:
|
||||
return tensor.repeat(3, 1, 1)
|
||||
if channels == 2:
|
||||
return torch.cat([tensor, tensor[:1]], dim=0)
|
||||
if channels > 3:
|
||||
return tensor[:3]
|
||||
raise ValueError(f"Unsupported number of channels: {channels}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -278,11 +260,6 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preserve_input_precision",
|
||||
action="store_true",
|
||||
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -476,41 +453,19 @@ def main(args):
|
||||
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
||||
|
||||
# Preprocessing the datasets and DataLoaders creation.
|
||||
spatial_augmentations = [
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
]
|
||||
|
||||
augmentations = transforms.Compose(
|
||||
spatial_augmentations
|
||||
+ [
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
precision_augmentations = transforms.Compose(
|
||||
[
|
||||
transforms.PILToTensor(),
|
||||
transforms.Lambda(_ensure_three_channels),
|
||||
transforms.ConvertImageDtype(torch.float32),
|
||||
]
|
||||
+ spatial_augmentations
|
||||
+ [transforms.Normalize([0.5], [0.5])]
|
||||
)
|
||||
|
||||
def transform_images(examples):
|
||||
processed = []
|
||||
for image in examples["image"]:
|
||||
if not args.preserve_input_precision:
|
||||
processed.append(augmentations(image.convert("RGB")))
|
||||
else:
|
||||
precise_image = image
|
||||
if precise_image.mode == "P":
|
||||
precise_image = precise_image.convert("RGB")
|
||||
processed.append(precision_augmentations(precise_image))
|
||||
return {"input": processed}
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from termcolor import colored
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SanaVideoPipeline,
|
||||
SanaVideoTransformer3DModel,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
|
||||
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
|
||||
|
||||
|
||||
def main(args):
|
||||
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
|
||||
|
||||
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
|
||||
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
|
||||
snapshot_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
file_path = hf_hub_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
else:
|
||||
file_path = args.orig_ckpt_path
|
||||
|
||||
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
|
||||
all_state_dict = torch.load(file_path, weights_only=True)
|
||||
state_dict = all_state_dict.pop("state_dict")
|
||||
converted_state_dict = {}
|
||||
|
||||
# Patch embeddings.
|
||||
converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Caption projection.
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
|
||||
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 8.0
|
||||
|
||||
# model config
|
||||
layer_num = 20
|
||||
# Positional embedding interpolation scale.
|
||||
qk_norm = True
|
||||
|
||||
# sample size
|
||||
if args.video_size == 480:
|
||||
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
|
||||
patch_size = (1, 2, 2)
|
||||
elif args.video_size == 720:
|
||||
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
|
||||
patch_size = (1, 1, 1)
|
||||
else:
|
||||
raise ValueError(f"Video size {args.video_size} is not supported.")
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
f"blocks.{depth}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Linear Attention is all you need 🤘
|
||||
# Self attention.
|
||||
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.k_norm.weight"
|
||||
)
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Feed-forward.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.point_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.t_conv.weight"
|
||||
)
|
||||
|
||||
# Cross-attention.
|
||||
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
|
||||
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
|
||||
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# Final block.
|
||||
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
|
||||
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer_kwargs = {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 20,
|
||||
"attention_head_dim": 112,
|
||||
"num_layers": 20,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"caption_channels": 2304,
|
||||
"mlp_ratio": 3.0,
|
||||
"attention_bias": False,
|
||||
"sample_size": sample_size,
|
||||
"patch_size": patch_size,
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 1024,
|
||||
}
|
||||
|
||||
transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
|
||||
|
||||
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
|
||||
|
||||
try:
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("logvar_linear.weight")
|
||||
state_dict.pop("logvar_linear.bias")
|
||||
except KeyError:
|
||||
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
|
||||
|
||||
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
print(f"Total number of transformer parameters: {num_model_params}")
|
||||
|
||||
transformer = transformer.to(weight_dtype)
|
||||
|
||||
if not args.save_full_pipeline:
|
||||
print(
|
||||
colored(
|
||||
f"Only saving transformer model of {args.model_type}. "
|
||||
f"Set --save_full_pipeline to save the whole Pipeline",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
transformer.save_pretrained(
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
else:
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
|
||||
tokenizer.padding_side = "right"
|
||||
text_encoder = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_model_path, torch_dtype=torch.bfloat16
|
||||
).get_decoder()
|
||||
|
||||
# Choose the appropriate pipeline and scheduler based on model type
|
||||
# Original Sana scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
elif args.scheduler_type == "uni-pc":
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction",
|
||||
use_flow_sigmas=True,
|
||||
num_train_timesteps=1000,
|
||||
flow_shift=flow_shift,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
|
||||
pipe = SanaVideoPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video_size",
|
||||
default=480,
|
||||
type=int,
|
||||
choices=[480, 720],
|
||||
required=False,
|
||||
help="Video size of pretrained model, 480 or 720.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="SanaVideo",
|
||||
type=str,
|
||||
choices=[
|
||||
"SanaVideo",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type",
|
||||
default="flow-dpm_solver",
|
||||
type=str,
|
||||
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
|
||||
help="Scheduler type to use.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
|
||||
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
@@ -122,7 +122,7 @@ _deps = [
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"python>=3.8.0",
|
||||
"python>=3.9.0",
|
||||
"ruff==0.9.10",
|
||||
"safetensors>=0.3.1",
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
@@ -287,7 +287,7 @@ setup(
|
||||
packages=find_packages("src"),
|
||||
package_data={"diffusers": ["py.typed"]},
|
||||
include_package_data=True,
|
||||
python_requires=">=3.8.0",
|
||||
python_requires=">=3.10.0",
|
||||
install_requires=list(install_requires),
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
|
||||
|
||||
@@ -246,7 +246,6 @@ else:
|
||||
"QwenImageTransformer2DModel",
|
||||
"SanaControlNetModel",
|
||||
"SanaTransformer2DModel",
|
||||
"SanaVideoTransformer3DModel",
|
||||
"SD3ControlNetModel",
|
||||
"SD3MultiControlNetModel",
|
||||
"SD3Transformer2DModel",
|
||||
@@ -545,7 +544,6 @@ else:
|
||||
"SanaPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SanaVideoPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -953,7 +951,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageTransformer2DModel,
|
||||
SanaControlNetModel,
|
||||
SanaTransformer2DModel,
|
||||
SanaVideoTransformer3DModel,
|
||||
SD3ControlNetModel,
|
||||
SD3MultiControlNetModel,
|
||||
SD3Transformer2DModel,
|
||||
@@ -1222,7 +1219,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SanaPipeline,
|
||||
SanaSprintImg2ImgPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaVideoPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
|
||||
+12
-12
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from .configuration_utils import ConfigMixin, register_to_config
|
||||
from .utils import CONFIG_NAME
|
||||
@@ -33,13 +33,13 @@ class PipelineCallback(ConfigMixin):
|
||||
raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
|
||||
|
||||
@property
|
||||
def tensor_inputs(self) -> List[str]:
|
||||
def tensor_inputs(self) -> list[str]:
|
||||
raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> dict[str, Any]:
|
||||
raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
|
||||
|
||||
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
|
||||
|
||||
|
||||
@@ -49,14 +49,14 @@ class MultiPipelineCallbacks:
|
||||
provides a unified interface for calling all of them.
|
||||
"""
|
||||
|
||||
def __init__(self, callbacks: List[PipelineCallback]):
|
||||
def __init__(self, callbacks: list[PipelineCallback]):
|
||||
self.callbacks = callbacks
|
||||
|
||||
@property
|
||||
def tensor_inputs(self) -> List[str]:
|
||||
def tensor_inputs(self) -> list[str]:
|
||||
return [input for callback in self.callbacks for input in callback.tensor_inputs]
|
||||
|
||||
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
|
||||
"""
|
||||
@@ -76,7 +76,7 @@ class SDCFGCutoffCallback(PipelineCallback):
|
||||
|
||||
tensor_inputs = ["prompt_embeds"]
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
@@ -109,7 +109,7 @@ class SDXLCFGCutoffCallback(PipelineCallback):
|
||||
"add_time_ids",
|
||||
]
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
@@ -152,7 +152,7 @@ class SDXLControlnetCFGCutoffCallback(PipelineCallback):
|
||||
"image",
|
||||
]
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
@@ -195,7 +195,7 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):
|
||||
|
||||
tensor_inputs = []
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
@@ -219,7 +219,7 @@ class SD3CFGCutoffCallback(PipelineCallback):
|
||||
|
||||
tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
|
||||
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
|
||||
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
|
||||
@@ -94,10 +94,10 @@ class ConfigMixin:
|
||||
Class attributes:
|
||||
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
||||
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
||||
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
- **ignore_for_config** (`list[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
overridden by subclass).
|
||||
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
||||
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
||||
- **_deprecated_kwargs** (`list[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
||||
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
||||
subclass).
|
||||
"""
|
||||
@@ -143,7 +143,7 @@ class ConfigMixin:
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
|
||||
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
def save_config(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
|
||||
[`~ConfigMixin.from_config`] class method.
|
||||
@@ -155,7 +155,7 @@ class ConfigMixin:
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
@@ -189,13 +189,13 @@ class ConfigMixin:
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
|
||||
) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
|
||||
cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs
|
||||
) -> Self | tuple[Self, dict[str, Any]]:
|
||||
r"""
|
||||
Instantiate a Python class from a config dictionary.
|
||||
|
||||
Parameters:
|
||||
config (`Dict[str, Any]`):
|
||||
config (`dict[str, Any]`):
|
||||
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
|
||||
files of compatible classes.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
@@ -292,11 +292,11 @@ class ConfigMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_config(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
pretrained_model_name_or_path: str | os.PathLike,
|
||||
return_unused_kwargs=False,
|
||||
return_commit_hash=False,
|
||||
**kwargs,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
r"""
|
||||
Load a model or scheduler configuration.
|
||||
|
||||
@@ -315,7 +315,7 @@ class ConfigMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
@@ -352,7 +352,7 @@ class ConfigMixin:
|
||||
_ = kwargs.pop("mirror", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
|
||||
user_agent = {**user_agent, "file_type": "config"}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
@@ -563,9 +563,7 @@ class ConfigMixin:
|
||||
return init_dict, unused_kwargs, hidden_config_dict
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(
|
||||
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
|
||||
):
|
||||
def _dict_from_json_file(cls, json_file: str | os.PathLike, dduf_entries: Optional[dict[str, DDUFEntry]] = None):
|
||||
if dduf_entries:
|
||||
text = dduf_entries[json_file].read_text()
|
||||
else:
|
||||
@@ -577,12 +575,12 @@ class ConfigMixin:
|
||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||
|
||||
@property
|
||||
def config(self) -> Dict[str, Any]:
|
||||
def config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the config of the class as a frozen dictionary
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Config of the class.
|
||||
`dict[str, Any]`: Config of the class.
|
||||
"""
|
||||
return self._internal_dict
|
||||
|
||||
@@ -625,7 +623,7 @@ class ConfigMixin:
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
def to_json_file(self, json_file_path: str | os.PathLike):
|
||||
"""
|
||||
Save the configuration instance's parameters to a JSON file.
|
||||
|
||||
@@ -637,7 +635,7 @@ class ConfigMixin:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
@classmethod
|
||||
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
|
||||
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: dict[str, DDUFEntry]):
|
||||
# paths inside a DDUF file must always be "/"
|
||||
config_file = (
|
||||
cls.config_name
|
||||
@@ -756,7 +754,7 @@ class LegacyConfigMixin(ConfigMixin):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
||||
def from_config(cls, config: FrozenDict | dict[str, Any] = None, return_unused_kwargs=False, **kwargs):
|
||||
# To prevent dependency import problem.
|
||||
from .models.model_loading_utils import _fetch_remapped_cls_from_config
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ deps = {
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
"python": "python>=3.8.0",
|
||||
"python": "python>=3.9.0",
|
||||
"ruff": "ruff==0.9.10",
|
||||
"safetensors": "safetensors>=0.3.1",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -77,7 +79,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -36,10 +38,10 @@ class AutoGuidance(BaseGuidance):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
auto_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
auto_guidance_layers (`int` or `list[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided.
|
||||
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
auto_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
dropout (`float`, *optional*):
|
||||
@@ -65,8 +67,8 @@ class AutoGuidance(BaseGuidance):
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
auto_guidance_layers: Optional[int | list[int]] = None,
|
||||
auto_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
||||
dropout: Optional[float] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
@@ -133,7 +135,7 @@ class AutoGuidance(BaseGuidance):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
registry.remove_hook(name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -91,7 +93,7 @@ class ClassifierFreeGuidance(BaseGuidance):
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -77,7 +79,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -37,7 +39,7 @@ else:
|
||||
build_laplacian_pyramid_func = None
|
||||
|
||||
|
||||
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
|
||||
(Algorithm 2).
|
||||
@@ -58,7 +60,7 @@ def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
|
||||
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
|
||||
def build_image_from_pyramid(pyramid: list[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
|
||||
(Algorithm 2).
|
||||
@@ -99,19 +101,19 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
|
||||
guidance_scales (`list[float]`, defaults to `[10.0, 5.0]`):
|
||||
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
|
||||
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
|
||||
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
|
||||
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
|
||||
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
|
||||
descending order).
|
||||
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
|
||||
guidance_rescale (`float` or `list[float]`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
|
||||
`guidance_scales`.
|
||||
parallel_weights (`float` or `List[float]`, *optional*):
|
||||
parallel_weights (`float` or `list[float]`, *optional*):
|
||||
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
|
||||
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
|
||||
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
|
||||
@@ -120,10 +122,10 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float` or `List[float]`, defaults to `0.0`):
|
||||
start (`float` or `list[float]`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
stop (`float` or `List[float]`, defaults to `1.0`):
|
||||
stop (`float` or `list[float]`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
guidance_rescale_space (`str`, defaults to `"data"`):
|
||||
@@ -141,12 +143,12 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
|
||||
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
|
||||
guidance_scales: list[float] | tuple[float] = [10.0, 5.0],
|
||||
guidance_rescale: float | list[float] | tuple[float] = 0.0,
|
||||
parallel_weights: Optional[float | list[float] | tuple[float]] = None,
|
||||
use_original_formulation: bool = False,
|
||||
start: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
stop: Union[float, List[float], Tuple[float]] = 1.0,
|
||||
start: float | list[float] | tuple[float] = 0.0,
|
||||
stop: float | list[float] | tuple[float] = 1.0,
|
||||
guidance_rescale_space: str = "data",
|
||||
upcast_to_double: bool = True,
|
||||
enabled: bool = True,
|
||||
@@ -218,7 +220,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
@@ -51,8 +53,8 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
||||
self._enabled = enabled
|
||||
self._input_fields: dict[str, str | tuple[str, str]] = None
|
||||
self._enabled = True
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
||||
@@ -101,7 +103,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep = timestep
|
||||
self._count_prepared = 0
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
def get_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
|
||||
the __repr__ method. Returns:
|
||||
@@ -163,10 +165,10 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: "BlockState") -> list["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
def __call__(self, data: list["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
if len(data) != self.num_conditions:
|
||||
@@ -194,7 +196,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
@classmethod
|
||||
def _prepare_batch(
|
||||
cls,
|
||||
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
|
||||
data: dict[str, tuple[torch.Tensor, torch.Tensor]],
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
@@ -203,7 +205,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
||||
|
||||
Args:
|
||||
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
input_fields (`dict[str, Union[str, tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
||||
@@ -238,7 +240,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
pretrained_model_name_or_path: Optional[str | os.PathLike] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
return_unused_kwargs=False,
|
||||
**kwargs,
|
||||
@@ -265,7 +267,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
@@ -295,7 +297,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a guider configuration object to a directory so that it can be reloaded using the
|
||||
[`~BaseGuidance.from_pretrained`] class method.
|
||||
@@ -307,7 +309,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -58,10 +60,10 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
|
||||
perturbed_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
|
||||
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
perturbed_guidance_layers (`int` or `list[int]`, *optional*):
|
||||
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
|
||||
If not provided, `perturbed_guidance_config` must be provided.
|
||||
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
perturbed_guidance_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
@@ -92,8 +94,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
perturbed_guidance_scale: float = 2.8,
|
||||
perturbed_guidance_start: float = 0.01,
|
||||
perturbed_guidance_stop: float = 0.2,
|
||||
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
perturbed_guidance_layers: Optional[int | list[int]] = None,
|
||||
perturbed_guidance_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
@@ -169,7 +171,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -64,11 +66,11 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
||||
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
||||
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
skip_layer_guidance_layers (`int` or `list[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
||||
3.5 Medium.
|
||||
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
skip_layer_config (`LayerSkipConfig` or `list[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
@@ -94,8 +96,8 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
skip_layer_guidance_layers: Optional[int | list[int]] = None,
|
||||
skip_layer_config: LayerSkipConfig | list[LayerSkipConfig] | dict[str, Any] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
@@ -165,7 +167,7 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -54,11 +56,11 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
|
||||
seg_guidance_stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
|
||||
seg_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
seg_guidance_layers (`int` or `list[int]`, *optional*):
|
||||
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
|
||||
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
|
||||
Diffusion 3.5 Medium.
|
||||
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
|
||||
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `list[SmoothedEnergyGuidanceConfig]`, *optional*):
|
||||
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
|
||||
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
@@ -86,8 +88,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
seg_blur_threshold_inf: float = 9999.0,
|
||||
seg_guidance_start: float = 0.0,
|
||||
seg_guidance_stop: float = 1.0,
|
||||
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
|
||||
seg_guidance_layers: Optional[int | list[int]] = None,
|
||||
seg_guidance_config: SmoothedEnergyGuidanceConfig | list[SmoothedEnergyGuidanceConfig] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
@@ -154,7 +156,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
for hook_name in self._seg_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -66,7 +68,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Type
|
||||
from typing import Any, Callable, Type
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -28,7 +28,7 @@ class TransformerBlockMetadata:
|
||||
return_encoder_hidden_states_index: int = None
|
||||
|
||||
_cls: Type = None
|
||||
_cached_parameter_indices: Dict[str, int] = None
|
||||
_cached_parameter_indices: dict[str, int] = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Type, Union
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -42,7 +42,7 @@ _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
|
||||
# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
|
||||
@dataclass
|
||||
class ModuleForwardMetadata:
|
||||
cached_parameter_indices: Dict[str, int] = None
|
||||
cached_parameter_indices: dict[str, int] = None
|
||||
_cls: Type = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
@@ -78,7 +78,7 @@ class ModuleForwardMetadata:
|
||||
def apply_context_parallel(
|
||||
module: torch.nn.Module,
|
||||
parallel_config: ContextParallelConfig,
|
||||
plan: Dict[str, ContextParallelModelPlan],
|
||||
plan: dict[str, ContextParallelModelPlan],
|
||||
) -> None:
|
||||
"""Apply context parallel on a model."""
|
||||
logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
|
||||
@@ -107,7 +107,7 @@ def apply_context_parallel(
|
||||
registry.register_hook(hook, hook_name)
|
||||
|
||||
|
||||
def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
|
||||
def remove_context_parallel(module: torch.nn.Module, plan: dict[str, ContextParallelModelPlan]) -> None:
|
||||
for module_id, cp_model_plan in plan.items():
|
||||
submodule = _get_submodule_by_name(module, module_id)
|
||||
if not isinstance(submodule, list):
|
||||
@@ -203,12 +203,10 @@ class ContextParallelSplitHook(ModelHook):
|
||||
|
||||
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
|
||||
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
|
||||
logger.warning_once(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
|
||||
raise ValueError(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
|
||||
)
|
||||
return x
|
||||
else:
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
|
||||
class ContextParallelGatherHook(ModelHook):
|
||||
@@ -274,13 +272,13 @@ class EquipartitionSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
|
||||
if name.count("*") > 1:
|
||||
raise ValueError("Wildcard '*' can only be used once in the name")
|
||||
return _find_submodule_by_name(model, name)
|
||||
|
||||
|
||||
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
def _find_submodule_by_name(model: torch.nn.Module, name: str) -> torch.nn.Module | list[torch.nn.Module]:
|
||||
if name == "":
|
||||
return model
|
||||
first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -60,7 +60,7 @@ class FasterCacheConfig:
|
||||
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
|
||||
be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
|
||||
states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
|
||||
spatial_attention_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 681)`):
|
||||
The timestep range within which the spatial attention computation can be skipped without a significant loss
|
||||
in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
@@ -68,17 +68,17 @@ class FasterCacheConfig:
|
||||
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
|
||||
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
|
||||
process.
|
||||
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
|
||||
temporal_attention_timestep_skip_range (`tuple[float, float]`, *optional*, defaults to `None`):
|
||||
The timestep range within which the temporal attention computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
|
||||
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
|
||||
timestep 0).
|
||||
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
|
||||
low_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(99, 901)`):
|
||||
The timestep range within which the low frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
|
||||
high_frequency_weight_update_timestep_range (`tuple[int, int]`, defaults to `(-1, 301)`):
|
||||
The timestep range within which the high frequency weight scaling update is applied. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
|
||||
function for the update is called only within this range.
|
||||
@@ -92,15 +92,15 @@ class FasterCacheConfig:
|
||||
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
|
||||
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
|
||||
computing the new unconditional branch states again.
|
||||
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
|
||||
unconditional_batch_timestep_skip_range (`tuple[float, float]`, defaults to `(-1, 641)`):
|
||||
The timestep range within which the unconditional branch computation can be skipped without a significant
|
||||
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
|
||||
tuple is the lower bound and the second value is the upper bound.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
||||
spatial_attention_block_identifiers (`tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
|
||||
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
||||
temporal_attention_block_identifiers (`tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
|
||||
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
|
||||
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
|
||||
partial layer names, or regex patterns. Matching will always be done using a regex match.
|
||||
@@ -123,7 +123,7 @@ class FasterCacheConfig:
|
||||
is_guidance_distilled (`bool`, defaults to `False`):
|
||||
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
|
||||
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
|
||||
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
||||
_unconditional_conditional_input_kwargs_identifiers (`list[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
|
||||
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
|
||||
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
|
||||
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
|
||||
@@ -135,12 +135,12 @@ class FasterCacheConfig:
|
||||
spatial_attention_block_skip_range: int = 2
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
|
||||
spatial_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
|
||||
temporal_attention_timestep_skip_range: tuple[int, int] = (-1, 681)
|
||||
|
||||
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
|
||||
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
|
||||
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
|
||||
low_frequency_weight_update_timestep_range: tuple[int, int] = (99, 901)
|
||||
high_frequency_weight_update_timestep_range: tuple[int, int] = (-1, 301)
|
||||
|
||||
# ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
|
||||
alpha_low_frequency: float = 1.1
|
||||
@@ -148,10 +148,10 @@ class FasterCacheConfig:
|
||||
|
||||
# n as described in CFG-Cache explanation in the paper - dependent on the model
|
||||
unconditional_batch_skip_range: int = 5
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
|
||||
unconditional_batch_timestep_skip_range: tuple[int, int] = (-1, 641)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
|
||||
attention_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
|
||||
@@ -162,7 +162,7 @@ class FasterCacheConfig:
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
||||
_unconditional_conditional_input_kwargs_identifiers: list[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
@@ -209,7 +209,7 @@ class FasterCacheBlockState:
|
||||
def __init__(self) -> None:
|
||||
self.iteration: int = 0
|
||||
self.batch_size: int = None
|
||||
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
|
||||
self.cache: tuple[torch.Tensor, torch.Tensor] = None
|
||||
|
||||
def reset(self):
|
||||
self.iteration = 0
|
||||
@@ -223,10 +223,10 @@ class FasterCacheDenoiserHook(ModelHook):
|
||||
def __init__(
|
||||
self,
|
||||
unconditional_batch_skip_range: int,
|
||||
unconditional_batch_timestep_skip_range: Tuple[int, int],
|
||||
unconditional_batch_timestep_skip_range: tuple[int, int],
|
||||
tensor_format: str,
|
||||
is_guidance_distilled: bool,
|
||||
uncond_cond_input_kwargs_identifiers: List[str],
|
||||
uncond_cond_input_kwargs_identifiers: list[str],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
|
||||
@@ -252,7 +252,7 @@ class FasterCacheDenoiserHook(ModelHook):
|
||||
return module
|
||||
|
||||
@staticmethod
|
||||
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_cond_input(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
|
||||
# followed by conditional inputs.
|
||||
_, cond = input.chunk(2, dim=0)
|
||||
@@ -371,7 +371,7 @@ class FasterCacheBlockHook(ModelHook):
|
||||
def __init__(
|
||||
self,
|
||||
block_skip_range: int,
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
timestep_skip_range: tuple[int, int],
|
||||
is_guidance_distilled: bool,
|
||||
weight_callback: Callable[[torch.nn.Module], float],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -53,9 +52,9 @@ class FBCSharedBlockState(BaseState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.head_block_output: torch.Tensor | tuple[torch.Tensor, ...] = None
|
||||
self.head_block_residual: torch.Tensor = None
|
||||
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.tail_block_residuals: torch.Tensor | tuple[torch.Tensor, ...] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Optional, Set
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
@@ -58,21 +58,21 @@ class GroupOffloadingConfig:
|
||||
low_cpu_mem_usage: bool
|
||||
num_blocks_per_group: Optional[int] = None
|
||||
offload_to_disk_path: Optional[str] = None
|
||||
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
|
||||
stream: Optional[torch.cuda.Stream | torch.Stream] = None
|
||||
|
||||
|
||||
class ModuleGroup:
|
||||
def __init__(
|
||||
self,
|
||||
modules: List[torch.nn.Module],
|
||||
modules: list[torch.nn.Module],
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
offload_leader: torch.nn.Module,
|
||||
onload_leader: Optional[torch.nn.Module] = None,
|
||||
parameters: Optional[List[torch.nn.Parameter]] = None,
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
parameters: Optional[list[torch.nn.Parameter]] = None,
|
||||
buffers: Optional[list[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
stream: torch.cuda.Stream | torch.Stream | None = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
onload_self: bool = True,
|
||||
@@ -340,7 +340,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self):
|
||||
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
|
||||
self.execution_order: list[tuple[str, torch.nn.Module]] = []
|
||||
self._layer_execution_tracker_module_names = set()
|
||||
|
||||
def initialize_hook(self, module):
|
||||
@@ -444,9 +444,9 @@ class LayerExecutionTrackerHook(ModelHook):
|
||||
|
||||
def apply_group_offloading(
|
||||
module: torch.nn.Module,
|
||||
onload_device: Union[str, torch.device],
|
||||
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
||||
offload_type: Union[str, GroupOffloadingType] = "block_level",
|
||||
onload_device: str | torch.device,
|
||||
offload_device: str | torch.device = torch.device("cpu"),
|
||||
offload_type: str | GroupOffloadingType = "block_level",
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
@@ -787,7 +787,7 @@ def _apply_lazy_group_offloading_hook(
|
||||
|
||||
def _gather_parameters_with_no_group_offloading_parent(
|
||||
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
||||
) -> List[torch.nn.Parameter]:
|
||||
) -> list[torch.nn.Parameter]:
|
||||
parameters = []
|
||||
for name, parameter in module.named_parameters():
|
||||
has_parent_with_group_offloading = False
|
||||
@@ -805,7 +805,7 @@ def _gather_parameters_with_no_group_offloading_parent(
|
||||
|
||||
def _gather_buffers_with_no_group_offloading_parent(
|
||||
module: torch.nn.Module, modules_with_group_offloading: Set[str]
|
||||
) -> List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
buffers = []
|
||||
for name, buffer in module.named_buffers():
|
||||
has_parent_with_group_offloading = False
|
||||
@@ -821,7 +821,7 @@ def _gather_buffers_with_no_group_offloading_parent(
|
||||
return buffers
|
||||
|
||||
|
||||
def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
|
||||
def _find_parent_module_in_module_dict(name: str, module_dict: dict[str, torch.nn.Module]) -> str:
|
||||
atoms = name.split(".")
|
||||
while len(atoms) > 0:
|
||||
parent_name = ".".join(atoms)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -86,19 +86,19 @@ class ModelHook:
|
||||
"""
|
||||
return module
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> tuple[tuple[Any], dict[str, Any]]:
|
||||
r"""
|
||||
Hook that is executed just before the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass will be executed just after this event.
|
||||
args (`Tuple[Any]`):
|
||||
args (`tuple[Any]`):
|
||||
The positional arguments passed to the module.
|
||||
kwargs (`Dict[Str, Any]`):
|
||||
kwargs (`dict[Str, Any]`):
|
||||
The keyword arguments passed to the module.
|
||||
Returns:
|
||||
`Tuple[Tuple[Any], Dict[Str, Any]]`:
|
||||
`tuple[tuple[Any], dict[Str, Any]]`:
|
||||
A tuple with the treated `args` and `kwargs`.
|
||||
"""
|
||||
return args, kwargs
|
||||
@@ -168,7 +168,7 @@ class HookRegistry:
|
||||
def __init__(self, module_ref: torch.nn.Module) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hooks: Dict[str, ModelHook] = {}
|
||||
self.hooks: dict[str, ModelHook] = {}
|
||||
|
||||
self._module_ref = module_ref
|
||||
self._hook_order = []
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -43,7 +43,7 @@ class LayerSkipConfig:
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
indices (`list[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
@@ -63,7 +63,7 @@ class LayerSkipConfig:
|
||||
skipped layers are fully retained, which is equivalent to not skipping any layers.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
indices: list[int]
|
||||
fqn: str = "auto"
|
||||
skip_attention: bool = True
|
||||
skip_attention_scores: bool = False
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
from typing import Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -102,8 +102,8 @@ def apply_layerwise_casting(
|
||||
module: torch.nn.Module,
|
||||
storage_dtype: torch.dtype,
|
||||
compute_dtype: torch.dtype,
|
||||
skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
|
||||
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
|
||||
skip_modules_pattern: str | tuple[str, ...] = "auto",
|
||||
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
r"""
|
||||
@@ -137,12 +137,12 @@ def apply_layerwise_casting(
|
||||
The dtype to cast the module to before/after the forward pass for storage.
|
||||
compute_dtype (`torch.dtype`):
|
||||
The dtype to cast the module to during the forward pass for computation.
|
||||
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
|
||||
skip_modules_pattern (`tuple[str, ...]`, defaults to `"auto"`):
|
||||
A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
|
||||
to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
|
||||
alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
|
||||
instead of its internal submodules.
|
||||
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
|
||||
skip_modules_classes (`tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
|
||||
A list of module classes to skip during the layerwise casting process.
|
||||
non_blocking (`bool`, defaults to `False`):
|
||||
If `True`, the weight casting operations are non-blocking.
|
||||
@@ -169,8 +169,8 @@ def _apply_layerwise_casting(
|
||||
module: torch.nn.Module,
|
||||
storage_dtype: torch.dtype,
|
||||
compute_dtype: torch.dtype,
|
||||
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
|
||||
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
|
||||
skip_modules_pattern: Optional[tuple[str, ...]] = None,
|
||||
skip_modules_classes: Optional[tuple[Type[torch.nn.Module], ...]] = None,
|
||||
non_blocking: bool = False,
|
||||
_prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -54,20 +54,20 @@ class PyramidAttentionBroadcastConfig:
|
||||
The number of times a specific cross-attention broadcast is skipped before computing the attention states
|
||||
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
|
||||
old attention states will be reused) before computing the new attention states again.
|
||||
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
spatial_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the spatial attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
temporal_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the temporal attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
cross_attention_timestep_skip_range (`tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The range of timesteps to skip in the cross-attention layer. The attention computations will be
|
||||
conditionally skipped if the current timestep is within the specified range.
|
||||
spatial_attention_block_identifiers (`Tuple[str, ...]`):
|
||||
spatial_attention_block_identifiers (`tuple[str, ...]`):
|
||||
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
|
||||
temporal_attention_block_identifiers (`Tuple[str, ...]`):
|
||||
temporal_attention_block_identifiers (`tuple[str, ...]`):
|
||||
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
|
||||
cross_attention_block_identifiers (`Tuple[str, ...]`):
|
||||
cross_attention_block_identifiers (`tuple[str, ...]`):
|
||||
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
|
||||
"""
|
||||
|
||||
@@ -75,13 +75,13 @@ class PyramidAttentionBroadcastConfig:
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
cross_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
spatial_attention_timestep_skip_range: tuple[int, int] = (100, 800)
|
||||
temporal_attention_timestep_skip_range: tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: tuple[int, int] = (100, 800)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
spatial_attention_block_identifiers: tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
@@ -141,7 +141,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(
|
||||
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
|
||||
self, timestep_skip_range: tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -288,8 +288,8 @@ def _apply_pyramid_attention_broadcast_on_attention_class(
|
||||
|
||||
|
||||
def _apply_pyramid_attention_broadcast_hook(
|
||||
module: Union[Attention, MochiAttention],
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
module: Attention | MochiAttention,
|
||||
timestep_skip_range: tuple[int, int],
|
||||
block_skip_range: int,
|
||||
current_timestep_callback: Callable[[], int],
|
||||
):
|
||||
@@ -299,7 +299,7 @@ def _apply_pyramid_attention_broadcast_hook(
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply Pyramid Attention Broadcast to.
|
||||
timestep_skip_range (`Tuple[int, int]`):
|
||||
timestep_skip_range (`tuple[int, int]`):
|
||||
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
|
||||
skipped if the current timestep is within the specified range.
|
||||
block_skip_range (`int`):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -35,21 +35,21 @@ class SmoothedEnergyGuidanceConfig:
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
indices (`list[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
|
||||
provide the correct fqn.
|
||||
_query_proj_identifiers (`List[str]`, defaults to `None`):
|
||||
_query_proj_identifiers (`list[str]`, defaults to `None`):
|
||||
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
|
||||
`None`, `to_q` is used by default.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
indices: list[int]
|
||||
fqn: str = "auto"
|
||||
_query_proj_identifiers: List[str] = None
|
||||
_query_proj_identifiers: list[str] = None
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
@@ -21,8 +21,8 @@ def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
|
||||
module_list_with_transformer_blocks = []
|
||||
for name, submodule in module.named_modules():
|
||||
name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
|
||||
is_modulelist = isinstance(submodule, torch.nn.ModuleList)
|
||||
if name_endswith_identifier and is_modulelist:
|
||||
is_ModuleList = isinstance(submodule, torch.nn.ModuleList)
|
||||
if name_endswith_identifier and is_ModuleList:
|
||||
module_list_with_transformer_blocks.append((name, submodule))
|
||||
return module_list_with_transformer_blocks
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -26,14 +26,9 @@ from .configuration_utils import ConfigMixin, register_to_config
|
||||
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
||||
|
||||
|
||||
PipelineImageInput = Union[
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
torch.Tensor,
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
List[torch.Tensor],
|
||||
]
|
||||
PipelineImageInput = (
|
||||
PIL.Image.Image | np.ndarray | torch.Tensor | list[PIL.Image.Image] | list[np.ndarray] | list[torch.Tensor]
|
||||
)
|
||||
|
||||
PipelineDepthInput = PipelineImageInput
|
||||
|
||||
@@ -68,7 +63,7 @@ def is_valid_image_imagelist(images):
|
||||
- A list of valid images.
|
||||
|
||||
Args:
|
||||
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
|
||||
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, list]`):
|
||||
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
|
||||
images.
|
||||
|
||||
@@ -131,7 +126,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
|
||||
r"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
|
||||
@@ -140,7 +135,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The image array to convert to PIL format.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
`list[PIL.Image.Image]`:
|
||||
A list of PIL images.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
@@ -155,12 +150,12 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return pil_images
|
||||
|
||||
@staticmethod
|
||||
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
||||
def pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
|
||||
r"""
|
||||
Convert a PIL image or a list of PIL images to NumPy arrays.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
|
||||
images (`PIL.Image.Image` or `list[PIL.Image.Image]`):
|
||||
The PIL image or list of images to convert to NumPy format.
|
||||
|
||||
Returns:
|
||||
@@ -210,7 +205,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return images
|
||||
|
||||
@staticmethod
|
||||
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
def normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
|
||||
r"""
|
||||
Normalize an image array to [-1,1].
|
||||
|
||||
@@ -225,7 +220,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return 2.0 * images - 1.0
|
||||
|
||||
@staticmethod
|
||||
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
def denormalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
|
||||
r"""
|
||||
Denormalize an image array to [0,1].
|
||||
|
||||
@@ -467,11 +462,11 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
image: PIL.Image.Image | np.ndarray | torch.Tensor,
|
||||
height: int,
|
||||
width: int,
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
|
||||
"""
|
||||
Resize image.
|
||||
|
||||
@@ -544,7 +539,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
return image
|
||||
|
||||
def _denormalize_conditionally(
|
||||
self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
|
||||
self, images: torch.Tensor, do_denormalize: Optional[list[bool]] = None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Denormalize a batch of images based on a condition list.
|
||||
@@ -552,7 +547,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The input image tensor.
|
||||
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
|
||||
do_denormalize (`Optional[list[bool]`, *optional*, defaults to `None`):
|
||||
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
|
||||
value of `do_normalize` in the `VaeImageProcessor` config.
|
||||
"""
|
||||
@@ -565,10 +560,10 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
def get_default_height_width(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
image: PIL.Image.Image | np.ndarray | torch.Tensor,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> Tuple[int, int]:
|
||||
) -> tuple[int, int]:
|
||||
r"""
|
||||
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
|
||||
|
||||
@@ -583,7 +578,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`:
|
||||
`tuple[int, int]`:
|
||||
A tuple containing the height and width, both resized to the nearest integer multiple of
|
||||
`vae_scale_factor`.
|
||||
"""
|
||||
@@ -616,7 +611,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
crops_coords: Optional[tuple[int, int, int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess the image input.
|
||||
@@ -638,7 +633,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
||||
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
||||
supported for PIL image input.
|
||||
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
crops_coords (`list[tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
||||
|
||||
Returns:
|
||||
@@ -745,8 +740,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
output_type: str = "pil",
|
||||
do_denormalize: Optional[List[bool]] = None,
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
do_denormalize: Optional[list[bool]] = None,
|
||||
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
|
||||
"""
|
||||
Postprocess the image output from tensor to `output_type`.
|
||||
|
||||
@@ -755,7 +750,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
||||
output_type (`str`, *optional*, defaults to `pil`):
|
||||
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
||||
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
||||
do_denormalize (`list[bool]`, *optional*, defaults to `None`):
|
||||
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
||||
`VaeImageProcessor` config.
|
||||
|
||||
@@ -796,7 +791,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
mask: PIL.Image.Image,
|
||||
init_image: PIL.Image.Image,
|
||||
image: PIL.Image.Image,
|
||||
crop_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
crop_coords: Optional[tuple[int, int, int, int]] = None,
|
||||
) -> PIL.Image.Image:
|
||||
r"""
|
||||
Applies an overlay of the mask and the inpainted image on the original image.
|
||||
@@ -808,7 +803,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The original image to which the overlay is applied.
|
||||
image (`PIL.Image.Image`):
|
||||
The image to overlay onto the original.
|
||||
crop_coords (`Tuple[int, int, int, int]`, *optional*):
|
||||
crop_coords (`tuple[int, int, int, int]`, *optional*):
|
||||
Coordinates to crop the image. If provided, the image will be cropped accordingly.
|
||||
|
||||
Returns:
|
||||
@@ -891,7 +886,7 @@ class InpaintProcessor(ConfigMixin):
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preprocess the image and mask.
|
||||
"""
|
||||
@@ -946,8 +941,8 @@ class InpaintProcessor(ConfigMixin):
|
||||
output_type: str = "pil",
|
||||
original_image: Optional[PIL.Image.Image] = None,
|
||||
original_mask: Optional[PIL.Image.Image] = None,
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
|
||||
crops_coords: Optional[tuple[int, int, int, int]] = None,
|
||||
) -> tuple[PIL.Image.Image, PIL.Image.Image]:
|
||||
"""
|
||||
Postprocess the image, optionally apply mask overlay
|
||||
"""
|
||||
@@ -998,7 +993,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
|
||||
r"""
|
||||
Convert a NumPy image or a batch of images to a list of PIL images.
|
||||
|
||||
@@ -1007,7 +1002,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
The input NumPy array of images, which can be a single image or a batch.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
`list[PIL.Image.Image]`:
|
||||
A list of PIL images converted from the input NumPy array.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
@@ -1022,12 +1017,12 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
return pil_images
|
||||
|
||||
@staticmethod
|
||||
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
||||
def depth_pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
|
||||
r"""
|
||||
Convert a PIL image or a list of PIL images to NumPy arrays.
|
||||
|
||||
Args:
|
||||
images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
|
||||
images (`Union[list[PIL.Image.Image], PIL.Image.Image]`):
|
||||
The input image or list of images to be converted.
|
||||
|
||||
Returns:
|
||||
@@ -1042,44 +1037,21 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
return images
|
||||
|
||||
@staticmethod
|
||||
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
def rgblike_to_depthmap(image: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
|
||||
r"""
|
||||
Convert an RGB-like depth image to a depth map.
|
||||
|
||||
Args:
|
||||
image (`Union[np.ndarray, torch.Tensor]`):
|
||||
The RGB-like depth image to convert.
|
||||
|
||||
Returns:
|
||||
`Union[np.ndarray, torch.Tensor]`:
|
||||
The corresponding depth map.
|
||||
"""
|
||||
# 1. Cast the tensor to a larger integer type (e.g., int32)
|
||||
# to safely perform the multiplication by 256.
|
||||
# 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
|
||||
# 3. Cast the final result to the desired depth map type (uint16) if needed
|
||||
# before returning, though leaving it as int32/int64 is often safer
|
||||
# for return value from a library function.
|
||||
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Cast to a safe dtype (e.g., int32 or int64) for the calculation
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.to(torch.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# You may want to cast the final result to uint16, but casting to a
|
||||
# larger int type (like int32) is sufficient to fix the overflow.
|
||||
# depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.to(original_dtype)
|
||||
|
||||
elif isinstance(image, np.ndarray):
|
||||
# NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.astype(np.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.astype(original_dtype)
|
||||
else:
|
||||
raise TypeError("Input image must be a torch.Tensor or np.ndarray")
|
||||
|
||||
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
def numpy_to_depth(self, images: np.ndarray) -> list[PIL.Image.Image]:
|
||||
r"""
|
||||
Convert a NumPy depth image or a batch of images to a list of PIL images.
|
||||
|
||||
@@ -1088,7 +1060,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
The input NumPy array of depth images, which can be a single image or a batch.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
`list[PIL.Image.Image]`:
|
||||
A list of PIL images converted from the input NumPy depth images.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
@@ -1111,8 +1083,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
output_type: str = "pil",
|
||||
do_denormalize: Optional[List[bool]] = None,
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
do_denormalize: Optional[list[bool]] = None,
|
||||
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
|
||||
"""
|
||||
Postprocess the image output from tensor to `output_type`.
|
||||
|
||||
@@ -1121,7 +1093,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
||||
output_type (`str`, *optional*, defaults to `pil`):
|
||||
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
||||
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
||||
do_denormalize (`list[bool]`, *optional*, defaults to `None`):
|
||||
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
||||
`VaeImageProcessor` config.
|
||||
|
||||
@@ -1159,8 +1131,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
|
||||
depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
|
||||
rgb: torch.Tensor | PIL.Image.Image | np.ndarray,
|
||||
depth: torch.Tensor | PIL.Image.Image | np.ndarray,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
target_res: Optional[int] = None,
|
||||
@@ -1181,7 +1153,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
Target resolution for resizing the images. If specified, overrides height and width.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
`tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing the processed RGB and depth images as PyTorch tensors.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
@@ -1419,7 +1391,7 @@ class PixArtImageProcessor(VaeImageProcessor):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
||||
def classify_height_width_bin(height: int, width: int, ratios: dict) -> tuple[int, int]:
|
||||
r"""
|
||||
Returns the binned height and width based on the aspect ratio.
|
||||
|
||||
@@ -1429,7 +1401,7 @@ class PixArtImageProcessor(VaeImageProcessor):
|
||||
ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`: The closest binned height and width.
|
||||
`tuple[int, int]`: The closest binned height and width.
|
||||
"""
|
||||
ar = float(height / width)
|
||||
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -57,15 +57,15 @@ class IPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
|
||||
subfolder: str | list[str],
|
||||
weight_name: str | list[str],
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
@@ -74,10 +74,10 @@ class IPAdapterMixin:
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
subfolder (`str` or `list[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
weight_name (`str` or `list[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
@@ -94,7 +94,7 @@ class IPAdapterMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -358,14 +358,14 @@ class ModularIPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
|
||||
subfolder: str | list[str],
|
||||
weight_name: str | list[str],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
@@ -374,10 +374,10 @@ class ModularIPAdapterMixin:
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
subfolder (`str` or `list[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
weight_name (`str` or `list[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
@@ -387,7 +387,7 @@ class ModularIPAdapterMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -608,9 +608,9 @@ class FluxIPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
weight_name: Union[str, List[str]],
|
||||
subfolder: Optional[Union[str, List[str]]] = "",
|
||||
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
|
||||
weight_name: str | list[str],
|
||||
subfolder: Optional[str | list[str]] = "",
|
||||
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||
image_encoder_subfolder: Optional[str] = "",
|
||||
image_encoder_dtype: torch.dtype = torch.float16,
|
||||
@@ -618,7 +618,7 @@ class FluxIPAdapterMixin:
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
pretrained_model_name_or_path_or_dict (`str` or `list[str]` or `os.PathLike` or `list[os.PathLike]` or `dict` or `list[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
@@ -627,10 +627,10 @@ class FluxIPAdapterMixin:
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
subfolder (`str` or `list[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
weight_name (`str` or `list[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`weight_name`.
|
||||
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
|
||||
@@ -647,7 +647,7 @@ class FluxIPAdapterMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -797,13 +797,13 @@ class FluxIPAdapterMixin:
|
||||
# load ip-adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||
def set_ip_adapter_scale(self, scale: float | list[float] | list[list[float]]):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||
|
||||
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
|
||||
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
|
||||
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `list[float]`
|
||||
length match the number of blocks, it is repeated for each IP adapter. `list[list[float]]` must match the
|
||||
number of IP adapters and each must match the number of blocks.
|
||||
|
||||
Example:
|
||||
@@ -823,18 +823,18 @@ class FluxIPAdapterMixin:
|
||||
```
|
||||
"""
|
||||
|
||||
scale_type = Union[int, float]
|
||||
scale_type = int | float
|
||||
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
num_layers = self.transformer.config.num_layers
|
||||
|
||||
# Single value for all layers of all IP-Adapters
|
||||
if isinstance(scale, scale_type):
|
||||
scale = [scale for _ in range(num_ip_adapters)]
|
||||
# List of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
||||
# list of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, list[scale_type]) and num_ip_adapters == 1:
|
||||
scale = [scale]
|
||||
# Invalid scale type
|
||||
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
||||
elif not _is_valid_type(scale, list[scale_type | list[scale_type]]):
|
||||
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
||||
|
||||
if len(scale) != num_ip_adapters:
|
||||
@@ -918,7 +918,7 @@ class SD3IPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
weight_name: str = "ip-adapter.safetensors",
|
||||
subfolder: Optional[str] = None,
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
@@ -953,7 +953,7 @@ class SD3IPAdapterMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
|
||||
@@ -17,7 +17,7 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -77,7 +77,7 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adap
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]` or `str`):
|
||||
adapter_names (`list[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
"""
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
@@ -116,20 +116,20 @@ def unfuse_text_encoder_lora(text_encoder):
|
||||
|
||||
|
||||
def set_adapters_for_text_encoder(
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_names: list[str] | str,
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||
text_encoder_weights: Optional[float | list[float] | list[None]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
adapter_names (`list[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
text_encoder_weights (`List[float]`, *optional*):
|
||||
text_encoder_weights (`list[float]`, *optional*):
|
||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
@@ -535,10 +535,10 @@ class LoraBaseMixin:
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = [],
|
||||
components: list[str] = [],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -547,12 +547,12 @@ class LoraBaseMixin:
|
||||
> [!WARNING] > This is an experimental API.
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
components: (`list[str]`): list of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
@@ -619,7 +619,7 @@ class LoraBaseMixin:
|
||||
|
||||
self._merged_adapters = self._merged_adapters | merged_adapter_names
|
||||
|
||||
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
||||
def unfuse_lora(self, components: list[str] = [], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
@@ -627,7 +627,7 @@ class LoraBaseMixin:
|
||||
> [!WARNING] > This is an experimental API.
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
components (`list[str]`): list of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
@@ -674,16 +674,16 @@ class LoraBaseMixin:
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
adapter_names: list[str] | str,
|
||||
adapter_weights: Optional[float | Dict | list[float] | list[Dict]] = None,
|
||||
):
|
||||
"""
|
||||
Set the currently active adapters for use in the pipeline.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
adapter_names (`list[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
adapter_weights (`Union[List[float], float]`, *optional*):
|
||||
adapter_weights (`Union[list[float], float]`, *optional*):
|
||||
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
|
||||
adapters.
|
||||
|
||||
@@ -835,12 +835,12 @@ class LoraBaseMixin:
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
enable_lora_for_text_encoder(model)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
def delete_adapters(self, adapter_names: list[str] | str):
|
||||
"""
|
||||
Delete an adapter's LoRA layers from the pipeline.
|
||||
|
||||
Args:
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
adapter_names (`Union[list[str], str]`):
|
||||
The names of the adapters to delete.
|
||||
|
||||
Example:
|
||||
@@ -873,7 +873,7 @@ class LoraBaseMixin:
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(model, adapter_name)
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
def get_active_adapters(self) -> list[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
@@ -906,7 +906,7 @@ class LoraBaseMixin:
|
||||
|
||||
return active_adapters
|
||||
|
||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||
def get_list_adapters(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Gets the current list of all available adapters in the pipeline.
|
||||
"""
|
||||
@@ -928,7 +928,7 @@ class LoraBaseMixin:
|
||||
|
||||
return set_adapters
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
def set_lora_device(self, adapter_names: list[str], device: torch.device | str | int) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
@@ -955,8 +955,8 @@ class LoraBaseMixin:
|
||||
```
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send device to.
|
||||
adapter_names (`list[str]`):
|
||||
list of adapters to send device to.
|
||||
device (`Union[torch.device, str, int]`):
|
||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||
"""
|
||||
@@ -1007,7 +1007,7 @@ class LoraBaseMixin:
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
is_main_process: bool,
|
||||
weight_name: str,
|
||||
@@ -1059,9 +1059,9 @@ class LoraBaseMixin:
|
||||
@classmethod
|
||||
def _save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
|
||||
lora_metadata: Dict[str, Optional[dict]],
|
||||
save_directory: str | os.PathLike,
|
||||
lora_layers: dict[str, dict[str, torch.nn.Module | torch.Tensor]],
|
||||
lora_metadata: dict[str, Optional[dict]],
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1021,7 +1020,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def _custom_replace(key: str, substrings: List[str]) -> str:
|
||||
def _custom_replace(key: str, substrings: list[str]) -> str:
|
||||
# Replaces the "."s with "_"s upto the `substrings`.
|
||||
# Example:
|
||||
# lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
|
||||
@@ -2213,10 +2212,6 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
|
||||
state_dict = {convert_key(k): v for k, v in state_dict.items()}
|
||||
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_default:
|
||||
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
|
||||
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,7 @@ import json
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -113,7 +113,7 @@ class PeftAdapterMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -127,7 +127,7 @@ class PeftAdapterMixin:
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -447,16 +447,16 @@ class PeftAdapterMixin:
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
|
||||
adapter_names: list[str] | str,
|
||||
weights: Optional[float | Dict | list[float] | list[Dict] | list[None]] = None,
|
||||
):
|
||||
"""
|
||||
Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
adapter_names (`list[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
adapter_weights (`Union[List[float], float]`, *optional*):
|
||||
adapter_weights (`Union[list[float], float]`, *optional*):
|
||||
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
|
||||
adapters.
|
||||
|
||||
@@ -539,7 +539,7 @@ class PeftAdapterMixin:
|
||||
inject_adapter_in_model(adapter_config, self, adapter_name)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
|
||||
def set_adapter(self, adapter_name: str | list[str]) -> None:
|
||||
"""
|
||||
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
|
||||
|
||||
@@ -547,7 +547,7 @@ class PeftAdapterMixin:
|
||||
[documentation](https://huggingface.co/docs/peft).
|
||||
|
||||
Args:
|
||||
adapter_name (Union[str, List[str]])):
|
||||
adapter_name (Union[str, list[str]])):
|
||||
The list of adapters to set or the adapter name in the case of a single adapter.
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
@@ -633,7 +633,7 @@ class PeftAdapterMixin:
|
||||
# support for older PEFT versions
|
||||
module.disable_adapters = False
|
||||
|
||||
def active_adapters(self) -> List[str]:
|
||||
def active_adapters(self) -> list[str]:
|
||||
"""
|
||||
Gets the current list of active adapters of the model.
|
||||
|
||||
@@ -756,12 +756,12 @@ class PeftAdapterMixin:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
set_adapter_layers(self, enabled=True)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
def delete_adapters(self, adapter_names: list[str] | str):
|
||||
"""
|
||||
Delete an adapter's LoRA layers from the underlying model.
|
||||
|
||||
Args:
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
adapter_names (`Union[list[str], str]`):
|
||||
The names (single string or list of strings) of the adapter to delete.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -290,7 +290,7 @@ class FromSingleFileMixin:
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
|
||||
@@ -229,7 +229,7 @@ class FromOriginalModelMixin:
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -112,7 +112,7 @@ class TextualInversionLoaderMixin:
|
||||
Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.
|
||||
"""
|
||||
|
||||
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
||||
def maybe_convert_prompt(self, prompt: str | list[str], tokenizer: "PreTrainedTokenizer"): # noqa: F821
|
||||
r"""
|
||||
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
|
||||
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
||||
@@ -127,14 +127,14 @@ class TextualInversionLoaderMixin:
|
||||
Returns:
|
||||
`str` or list of `str`: The converted prompt
|
||||
"""
|
||||
if not isinstance(prompt, List):
|
||||
if not isinstance(prompt, list):
|
||||
prompts = [prompt]
|
||||
else:
|
||||
prompts = prompt
|
||||
|
||||
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
|
||||
|
||||
if not isinstance(prompt, List):
|
||||
if not isinstance(prompt, list):
|
||||
return prompts[0]
|
||||
|
||||
return prompts
|
||||
@@ -263,8 +263,8 @@ class TextualInversionLoaderMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_textual_inversion(
|
||||
self,
|
||||
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
||||
token: Optional[Union[str, List[str]]] = None,
|
||||
pretrained_model_name_or_path: str | list[str] | dict[str, torch.Tensor] | list[dict[str, torch.Tensor]],
|
||||
token: Optional[str | list[str]] = None,
|
||||
tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
**kwargs,
|
||||
@@ -274,7 +274,7 @@ class TextualInversionLoaderMixin:
|
||||
Automatic1111 formats are supported).
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike` or `list[str or os.PathLike]` or `Dict` or `list[Dict]`):
|
||||
Can be either one of the following or a list of them:
|
||||
|
||||
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
|
||||
@@ -285,7 +285,7 @@ class TextualInversionLoaderMixin:
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
token (`str` or `List[str]`, *optional*):
|
||||
token (`str` or `list[str]`, *optional*):
|
||||
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
|
||||
list, then `token` must also be a list of equal length.
|
||||
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
|
||||
@@ -306,7 +306,7 @@ class TextualInversionLoaderMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -458,7 +458,7 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
def unload_textual_inversion(
|
||||
self,
|
||||
tokens: Optional[Union[str, List[str]]] = None,
|
||||
tokens: Optional[str | list[str]] = None,
|
||||
tokenizer: Optional["PreTrainedTokenizer"] = None,
|
||||
text_encoder: Optional["PreTrainedModel"] = None,
|
||||
):
|
||||
|
||||
@@ -15,7 +15,7 @@ import os
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Union
|
||||
from typing import Callable
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -66,7 +66,7 @@ class UNet2DConditionLoadersMixin:
|
||||
unet_name = UNET_NAME
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], **kwargs):
|
||||
r"""
|
||||
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
|
||||
defined in
|
||||
@@ -92,7 +92,7 @@ class UNet2DConditionLoadersMixin:
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
@@ -106,7 +106,7 @@ class UNet2DConditionLoadersMixin:
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
network_alphas (`dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
@@ -412,7 +412,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
def save_attn_procs(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from torch import nn
|
||||
|
||||
@@ -40,9 +40,7 @@ def _translate_into_actual_layer_name(name):
|
||||
return ".".join((updown, block, attn))
|
||||
|
||||
|
||||
def _maybe_expand_lora_scales(
|
||||
unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
|
||||
):
|
||||
def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: list[float | Dict], default_scale=1.0):
|
||||
blocks_with_transformer = {
|
||||
"down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
|
||||
"up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
|
||||
@@ -64,9 +62,9 @@ def _maybe_expand_lora_scales(
|
||||
|
||||
|
||||
def _maybe_expand_lora_scales_for_one_adapter(
|
||||
scales: Union[float, Dict],
|
||||
blocks_with_transformer: Dict[str, int],
|
||||
transformer_per_block: Dict[str, int],
|
||||
scales: float | Dict,
|
||||
blocks_with_transformer: dict[str, int],
|
||||
transformer_per_block: dict[str, int],
|
||||
model: nn.Module,
|
||||
default_scale: float = 1.0,
|
||||
):
|
||||
@@ -76,9 +74,9 @@ def _maybe_expand_lora_scales_for_one_adapter(
|
||||
Parameters:
|
||||
scales (`Union[float, Dict]`):
|
||||
Scales dict to expand.
|
||||
blocks_with_transformer (`Dict[str, int]`):
|
||||
blocks_with_transformer (`dict[str, int]`):
|
||||
Dict with keys 'up' and 'down', showing which blocks have transformer layers
|
||||
transformer_per_block (`Dict[str, int]`):
|
||||
transformer_per_block (`dict[str, int]`):
|
||||
Dict with keys 'up' and 'down', showing how many transformer layers each block has
|
||||
|
||||
E.g. turns
|
||||
|
||||
@@ -12,13 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AttnProcsLayers(torch.nn.Module):
|
||||
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
||||
def __init__(self, state_dict: dict[str, torch.Tensor]):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(state_dict.values())
|
||||
self.mapping = dict(enumerate(state_dict.keys()))
|
||||
|
||||
@@ -102,7 +102,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
@@ -205,7 +204,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PRXTransformer2DModel,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SanaVideoTransformer3DModel,
|
||||
SD3Transformer2DModel,
|
||||
SkyReelsV2Transformer3DModel,
|
||||
StableAudioDiTModel,
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -187,19 +187,17 @@ class ContextParallelOutput:
|
||||
# If the key is a string, it denotes the name of the parameter in the forward function.
|
||||
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
|
||||
# to be split across context parallel region.
|
||||
ContextParallelInputType = Dict[
|
||||
Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
|
||||
ContextParallelInputType = dict[
|
||||
str | int, ContextParallelInput | list[ContextParallelInput] | tuple[ContextParallelInput, ...]
|
||||
]
|
||||
|
||||
# A dictionary where keys denote the output to be gathered across context parallel region, and the
|
||||
# value denotes the gathering configuration.
|
||||
ContextParallelOutputType = Union[
|
||||
ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
|
||||
]
|
||||
ContextParallelOutputType = ContextParallelOutput | list[ContextParallelOutput] | tuple[ContextParallelOutput, ...]
|
||||
|
||||
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
|
||||
# the module should be split/gathered across context parallel region.
|
||||
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
|
||||
ContextParallelModelPlan = dict[str, ContextParallelInputType | ContextParallelOutputType]
|
||||
|
||||
|
||||
# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -34,11 +34,11 @@ class MultiAdapter(ModelMixin):
|
||||
or saving.
|
||||
|
||||
Args:
|
||||
adapters (`List[T2IAdapter]`, *optional*, defaults to None):
|
||||
adapters (`list[T2IAdapter]`, *optional*, defaults to None):
|
||||
A list of `T2IAdapter` model instances.
|
||||
"""
|
||||
|
||||
def __init__(self, adapters: List["T2IAdapter"]):
|
||||
def __init__(self, adapters: list["T2IAdapter"]):
|
||||
super(MultiAdapter, self).__init__()
|
||||
|
||||
self.num_adapter = len(adapters)
|
||||
@@ -73,7 +73,7 @@ class MultiAdapter(ModelMixin):
|
||||
self.total_downscale_factor = first_adapter_total_downscale_factor
|
||||
self.downscale_factor = first_adapter_downscale_factor
|
||||
|
||||
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
|
||||
def forward(self, xs: torch.Tensor, adapter_weights: Optional[list[float]] = None) -> list[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
xs (`torch.Tensor`):
|
||||
@@ -81,7 +81,7 @@ class MultiAdapter(ModelMixin):
|
||||
models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
|
||||
`num_adapter` * number of channel per image.
|
||||
|
||||
adapter_weights (`List[float]`, *optional*, defaults to None):
|
||||
adapter_weights (`list[float]`, *optional*, defaults to None):
|
||||
A list of floats representing the weights which will be multiplied by each adapter's output before
|
||||
summing them together. If `None`, equal weights will be used for all adapters.
|
||||
"""
|
||||
@@ -104,7 +104,7 @@ class MultiAdapter(ModelMixin):
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
@@ -145,7 +145,7 @@ class MultiAdapter(ModelMixin):
|
||||
model_path_to_save = model_path_to_save + f"_{idx}"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[str | os.PathLike], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
|
||||
|
||||
@@ -165,7 +165,7 @@ class MultiAdapter(ModelMixin):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
@@ -229,7 +229,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
in_channels (`int`, *optional*, defaults to `3`):
|
||||
The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
|
||||
image.
|
||||
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
channels (`list[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
|
||||
determines the number of downsample blocks in the adapter.
|
||||
num_res_blocks (`int`, *optional*, defaults to `2`):
|
||||
@@ -244,7 +244,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280, 1280],
|
||||
channels: list[int] = [320, 640, 1280, 1280],
|
||||
num_res_blocks: int = 2,
|
||||
downscale_factor: int = 8,
|
||||
adapter_type: str = "full_adapter",
|
||||
@@ -263,7 +263,7 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
"'full_adapter_xl' or 'light_adapter'."
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
r"""
|
||||
This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
|
||||
each representing information extracted at a different scale from the input. The length of the list is
|
||||
@@ -295,7 +295,7 @@ class FullAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280, 1280],
|
||||
channels: list[int] = [320, 640, 1280, 1280],
|
||||
num_res_blocks: int = 2,
|
||||
downscale_factor: int = 8,
|
||||
):
|
||||
@@ -318,7 +318,7 @@ class FullAdapter(nn.Module):
|
||||
|
||||
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
r"""
|
||||
This method processes the input tensor `x` through the FullAdapter model and performs operations including
|
||||
pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
|
||||
@@ -345,7 +345,7 @@ class FullAdapterXL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280, 1280],
|
||||
channels: list[int] = [320, 640, 1280, 1280],
|
||||
num_res_blocks: int = 2,
|
||||
downscale_factor: int = 16,
|
||||
):
|
||||
@@ -370,7 +370,7 @@ class FullAdapterXL(nn.Module):
|
||||
# XL has only one downsampling AdapterBlock.
|
||||
self.total_downscale_factor = downscale_factor * 2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
r"""
|
||||
This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
|
||||
including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
|
||||
@@ -473,7 +473,7 @@ class LightAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280],
|
||||
channels: list[int] = [320, 640, 1280],
|
||||
num_res_blocks: int = 4,
|
||||
downscale_factor: int = 8,
|
||||
):
|
||||
@@ -496,7 +496,7 @@ class LightAdapter(nn.Module):
|
||||
|
||||
self.total_downscale_factor = downscale_factor * (2 ** len(channels))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
r"""
|
||||
This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
|
||||
feature tensor corresponds to a different level of processing within the LightAdapter.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -38,7 +38,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
class AttentionMixin:
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -47,7 +47,7 @@ class AttentionMixin:
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -61,7 +61,7 @@ class AttentionMixin:
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -184,7 +184,7 @@ class AttentionModuleMixin:
|
||||
def set_use_xla_flash_attention(
|
||||
self,
|
||||
use_xla_flash_attention: bool,
|
||||
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
|
||||
partition_spec: Optional[tuple[Optional[str], ...]] = None,
|
||||
is_flux=False,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -193,7 +193,7 @@ class AttentionModuleMixin:
|
||||
Args:
|
||||
use_xla_flash_attention (`bool`):
|
||||
Whether to use pallas flash attention kernel from `torch_xla` or not.
|
||||
partition_spec (`Tuple[]`, *optional*):
|
||||
partition_spec (`tuple[]`, *optional*):
|
||||
Specify the partition specification if using SPMD. Otherwise None.
|
||||
is_flux (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model is a Flux model.
|
||||
@@ -669,8 +669,8 @@ class JointTransformerBlock(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
@@ -950,9 +950,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
cross_attention_kwargs: dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
@@ -1487,7 +1487,7 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
|
||||
def _get_frame_indices(self, num_frames: int) -> list[tuple[int, int]]:
|
||||
frame_indices = []
|
||||
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
|
||||
window_start = i
|
||||
@@ -1495,7 +1495,7 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
frame_indices.append((window_start, window_end))
|
||||
return frame_indices
|
||||
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> list[float]:
|
||||
if weighting_scheme == "flat":
|
||||
weights = [1.0] * num_frames
|
||||
|
||||
@@ -1545,7 +1545,7 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
cross_attention_kwargs: dict[str, Any] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -12,12 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -228,7 +230,7 @@ class _AttentionBackendRegistry:
|
||||
def register(
|
||||
cls,
|
||||
backend: AttentionBackendName,
|
||||
constraints: Optional[List[Callable]] = None,
|
||||
constraints: Optional[list[Callable]] = None,
|
||||
supports_context_parallel: bool = False,
|
||||
):
|
||||
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
|
||||
@@ -263,7 +265,7 @@ class _AttentionBackendRegistry:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
|
||||
def attention_backend(backend: str | AttentionBackendName = AttentionBackendName.NATIVE):
|
||||
"""
|
||||
Context manager to set the active attention backend.
|
||||
"""
|
||||
@@ -291,7 +293,7 @@ def dispatch_attention_fn(
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
backend: Optional[AttentionBackendName] = None,
|
||||
parallel_config: Optional["ParallelConfig"] = None,
|
||||
@@ -595,7 +597,7 @@ def _wrapped_flash_attn_3(
|
||||
pack_gqa: Optional[bool] = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Hardcoded for now because pytorch does not support tuple/int type hints
|
||||
window_size = (-1, -1)
|
||||
out, lse, *_ = flash_attn_3_func(
|
||||
@@ -637,7 +639,7 @@ def _(
|
||||
pack_gqa: Optional[bool] = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
window_size = (-1, -1) # noqa: F841
|
||||
# A lot of the parameters here are not yet used in any way within diffusers.
|
||||
# We can safely ignore for now and keep the fake op shape propagation simple.
|
||||
@@ -649,86 +651,6 @@ def _(
|
||||
# ===== Helper functions to use attention backends with templated CP autograd functions =====
|
||||
|
||||
|
||||
def _native_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
# Native attention does not return_lse
|
||||
if return_lse:
|
||||
raise ValueError("Native attention does not support return_lse=True")
|
||||
|
||||
# used for backward pass
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.attn_mask = attn_mask
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.is_causal = is_causal
|
||||
ctx.scale = scale
|
||||
ctx.enable_gqa = enable_gqa
|
||||
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _native_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
query, key, value = ctx.saved_tensors
|
||||
|
||||
query.requires_grad_(True)
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
grad_value = grad_value_t.permute(0, 2, 1, 3)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
|
||||
# forward declaration:
|
||||
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
@@ -1415,7 +1337,7 @@ def _flash_attention_3_hub(
|
||||
value: torch.Tensor,
|
||||
scale: Optional[float] = None,
|
||||
is_causal: bool = False,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
deterministic: bool = False,
|
||||
return_attn_probs: bool = False,
|
||||
@@ -1545,7 +1467,7 @@ def _native_flex_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
|
||||
attn_mask: Optional[torch.Tensor | "flex_attention.BlockMask"] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
@@ -1603,7 +1525,6 @@ def _native_flex_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.NATIVE,
|
||||
constraints=[_check_device, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _native_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -1619,35 +1540,18 @@ def _native_attention(
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
|
||||
if _parallel_config is None:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op=_native_attention_forward_op,
|
||||
backward_op=_native_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import math
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -309,7 +309,7 @@ class Attention(nn.Module):
|
||||
def set_use_xla_flash_attention(
|
||||
self,
|
||||
use_xla_flash_attention: bool,
|
||||
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
|
||||
partition_spec: Optional[tuple[Optional[str], ...]] = None,
|
||||
is_flux=False,
|
||||
) -> None:
|
||||
r"""
|
||||
@@ -318,7 +318,7 @@ class Attention(nn.Module):
|
||||
Args:
|
||||
use_xla_flash_attention (`bool`):
|
||||
Whether to use pallas flash attention kernel from `torch_xla` or not.
|
||||
partition_spec (`Tuple[]`, *optional*):
|
||||
partition_spec (`tuple[]`, *optional*):
|
||||
Specify the partition specification if using SPMD. Otherwise None.
|
||||
"""
|
||||
if use_xla_flash_attention:
|
||||
@@ -872,7 +872,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
|
||||
attention_head_dim: int = 8,
|
||||
mult: float = 1.0,
|
||||
norm_type: str = "batch_norm",
|
||||
kernel_sizes: Tuple[int, ...] = (5,),
|
||||
kernel_sizes: tuple[int, ...] = (5,),
|
||||
eps: float = 1e-15,
|
||||
residual_connection: bool = False,
|
||||
):
|
||||
@@ -2790,7 +2790,7 @@ class XLAFlashAttnProcessor2_0:
|
||||
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
|
||||
"""
|
||||
|
||||
def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
|
||||
def __init__(self, partition_spec: Optional[tuple[Optional[str], ...]] = None):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
@@ -3001,7 +3001,7 @@ class StableAudioAttnProcessor2_0:
|
||||
def apply_partial_rotary_emb(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Tuple[torch.Tensor],
|
||||
freqs_cis: tuple[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
@@ -4212,9 +4212,9 @@ class IPAdapterAttnProcessor(nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or List[`float`], defaults to 1.0):
|
||||
scale (`float` or list[`float`], defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
@@ -4305,7 +4305,7 @@ class IPAdapterAttnProcessor(nn.Module):
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
if not isinstance(ip_adapter_masks, list):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
@@ -4412,9 +4412,9 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or `List[float]`, defaults to 1.0):
|
||||
scale (`float` or `list[float]`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
@@ -4524,7 +4524,7 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
if not isinstance(ip_adapter_masks, list):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
@@ -4644,9 +4644,9 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
num_tokens (`int`, `tuple[int]` or `list[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or `List[float]`, defaults to 1.0):
|
||||
scale (`float` or `list[float]`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
attention_op (`Callable`, *optional*, defaults to `None`):
|
||||
The base
|
||||
@@ -4763,7 +4763,7 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
|
||||
if ip_hidden_states:
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
if not isinstance(ip_adapter_masks, list):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
@@ -5622,56 +5622,56 @@ CROSS_ATTENTION_PROCESSORS = (
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
AttentionProcessor = Union[
|
||||
AttnProcessor,
|
||||
CustomDiffusionAttnProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
JointAttnProcessor2_0,
|
||||
PAGJointAttnProcessor2_0,
|
||||
PAGCFGJointAttnProcessor2_0,
|
||||
FusedJointAttnProcessor2_0,
|
||||
AllegroAttnProcessor2_0,
|
||||
AuraFlowAttnProcessor2_0,
|
||||
FusedAuraFlowAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0_NPU,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
FusedFluxAttnProcessor2_0_NPU,
|
||||
CogVideoXAttnProcessor2_0,
|
||||
FusedCogVideoXAttnProcessor2_0,
|
||||
XFormersAttnAddedKVProcessor,
|
||||
XFormersAttnProcessor,
|
||||
XLAFlashAttnProcessor2_0,
|
||||
AttnProcessorNPU,
|
||||
AttnProcessor2_0,
|
||||
MochiVaeAttnProcessor2_0,
|
||||
MochiAttnProcessor2_0,
|
||||
StableAudioAttnProcessor2_0,
|
||||
HunyuanAttnProcessor2_0,
|
||||
FusedHunyuanAttnProcessor2_0,
|
||||
PAGHunyuanAttnProcessor2_0,
|
||||
PAGCFGHunyuanAttnProcessor2_0,
|
||||
LuminaAttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
SlicedAttnProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
SanaLinearAttnProcessor2_0,
|
||||
PAGCFGSanaLinearAttnProcessor2_0,
|
||||
PAGIdentitySanaLinearAttnProcessor2_0,
|
||||
SanaMultiscaleLinearAttention,
|
||||
SanaMultiscaleAttnProcessor2_0,
|
||||
SanaMultiscaleAttentionProjection,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
PAGIdentitySelfAttnProcessor2_0,
|
||||
PAGCFGIdentitySelfAttnProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
]
|
||||
AttentionProcessor = (
|
||||
AttnProcessor
|
||||
| CustomDiffusionAttnProcessor
|
||||
| AttnAddedKVProcessor
|
||||
| AttnAddedKVProcessor2_0
|
||||
| JointAttnProcessor2_0
|
||||
| PAGJointAttnProcessor2_0
|
||||
| PAGCFGJointAttnProcessor2_0
|
||||
| FusedJointAttnProcessor2_0
|
||||
| AllegroAttnProcessor2_0
|
||||
| AuraFlowAttnProcessor2_0
|
||||
| FusedAuraFlowAttnProcessor2_0
|
||||
| FluxAttnProcessor2_0
|
||||
| FluxAttnProcessor2_0_NPU
|
||||
| FusedFluxAttnProcessor2_0
|
||||
| FusedFluxAttnProcessor2_0_NPU
|
||||
| CogVideoXAttnProcessor2_0
|
||||
| FusedCogVideoXAttnProcessor2_0
|
||||
| XFormersAttnAddedKVProcessor
|
||||
| XFormersAttnProcessor
|
||||
| XLAFlashAttnProcessor2_0
|
||||
| AttnProcessorNPU
|
||||
| AttnProcessor2_0
|
||||
| MochiVaeAttnProcessor2_0
|
||||
| MochiAttnProcessor2_0
|
||||
| StableAudioAttnProcessor2_0
|
||||
| HunyuanAttnProcessor2_0
|
||||
| FusedHunyuanAttnProcessor2_0
|
||||
| PAGHunyuanAttnProcessor2_0
|
||||
| PAGCFGHunyuanAttnProcessor2_0
|
||||
| LuminaAttnProcessor2_0
|
||||
| FusedAttnProcessor2_0
|
||||
| CustomDiffusionXFormersAttnProcessor
|
||||
| CustomDiffusionAttnProcessor2_0
|
||||
| SlicedAttnProcessor
|
||||
| SlicedAttnAddedKVProcessor
|
||||
| SanaLinearAttnProcessor2_0
|
||||
| PAGCFGSanaLinearAttnProcessor2_0
|
||||
| PAGIdentitySanaLinearAttnProcessor2_0
|
||||
| SanaMultiscaleLinearAttention
|
||||
| SanaMultiscaleAttnProcessor2_0
|
||||
| SanaMultiscaleAttentionProjection
|
||||
| IPAdapterAttnProcessor
|
||||
| IPAdapterAttnProcessor2_0
|
||||
| IPAdapterXFormersAttnProcessor
|
||||
| SD3IPAdapterJointAttnProcessor2_0
|
||||
| PAGIdentitySelfAttnProcessor2_0
|
||||
| PAGCFGIdentitySelfAttnProcessor2_0
|
||||
| LoRAAttnProcessor
|
||||
| LoRAAttnProcessor2_0
|
||||
| LoRAXFormersAttnProcessor
|
||||
| LoRAAttnAddedKVProcessor
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
@@ -37,7 +37,7 @@ class AutoModel(ConfigMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_or_path: Optional[str | os.PathLike] = None, **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
||||
|
||||
@@ -61,7 +61,7 @@ class AutoModel(ConfigMixin):
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info (`bool`, *optional*, defaults to `False`):
|
||||
@@ -83,7 +83,7 @@ class AutoModel(ConfigMixin):
|
||||
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
|
||||
@@ -147,13 +147,14 @@ class AutoModel(ConfigMixin):
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
|
||||
|
||||
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
|
||||
|
||||
library = None
|
||||
orig_class_name = None
|
||||
@@ -204,6 +205,7 @@ class AutoModel(ConfigMixin):
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -34,16 +34,16 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of down block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
tuple of downsample block types.
|
||||
down_block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of down block output channels.
|
||||
layers_per_down_block (`int`, *optional*, defaults to `1`):
|
||||
Number layers for down block.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of up block output channels.
|
||||
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
tuple of upsample block types.
|
||||
up_block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of up block output channels.
|
||||
layers_per_up_block (`int`, *optional*, defaults to `1`):
|
||||
Number layers for up block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
@@ -67,11 +67,11 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
down_block_out_channels: Tuple[int, ...] = (64,),
|
||||
down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
down_block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_down_block: int = 1,
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
up_block_out_channels: Tuple[int, ...] = (64,),
|
||||
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
up_block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_up_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
@@ -111,7 +111,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self.register_to_config(force_upcast=False)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput | tuple[torch.Tensor]:
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
@@ -127,7 +127,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
image: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z, image, mask)
|
||||
|
||||
@@ -144,7 +144,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
image: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
decoded = self._decode(z, image, mask).sample
|
||||
|
||||
if not return_dict:
|
||||
@@ -159,7 +159,7 @@ class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -68,7 +68,7 @@ class EfficientViTBlock(nn.Module):
|
||||
in_channels: int,
|
||||
mult: float = 1.0,
|
||||
attention_head_dim: int = 32,
|
||||
qkv_multiscales: Tuple[int, ...] = (5,),
|
||||
qkv_multiscales: tuple[int, ...] = (5,),
|
||||
norm_type: str = "batch_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -102,7 +102,7 @@ def get_block(
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: Tuple[int] = (),
|
||||
qkv_mutliscales: tuple[int] = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
@@ -205,10 +205,10 @@ class Encoder(nn.Module):
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
block_type: str | tuple[str] = "ResBlock",
|
||||
block_out_channels: tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
):
|
||||
@@ -291,12 +291,12 @@ class Decoder(nn.Module):
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
||||
act_fn: Union[str, Tuple[str]] = "silu",
|
||||
block_type: str | tuple[str] = "ResBlock",
|
||||
block_out_channels: tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: str | tuple[str] = "rms_norm",
|
||||
act_fn: str | tuple[str] = "silu",
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
in_shortcut: bool = True,
|
||||
conv_act_fn: str = "relu",
|
||||
@@ -391,29 +391,29 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
The number of input channels in samples.
|
||||
latent_channels (`int`, defaults to `32`):
|
||||
The number of channels in the latent space representation.
|
||||
encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
|
||||
encoder_block_types (`Union[str, tuple[str]]`, defaults to `"ResBlock"`):
|
||||
The type(s) of block to use in the encoder.
|
||||
decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
|
||||
decoder_block_types (`Union[str, tuple[str]]`, defaults to `"ResBlock"`):
|
||||
The type(s) of block to use in the decoder.
|
||||
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
encoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
The number of output channels for each block in the encoder.
|
||||
decoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
decoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
The number of output channels for each block in the decoder.
|
||||
encoder_layers_per_block (`Tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
|
||||
encoder_layers_per_block (`tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
|
||||
The number of layers per block in the encoder.
|
||||
decoder_layers_per_block (`Tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
|
||||
decoder_layers_per_block (`tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
|
||||
The number of layers per block in the decoder.
|
||||
encoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
encoder_qkv_multiscales (`tuple[tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
Multi-scale configurations for the encoder's QKV (query-key-value) transformations.
|
||||
decoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
decoder_qkv_multiscales (`tuple[tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
Multi-scale configurations for the decoder's QKV (query-key-value) transformations.
|
||||
upsample_block_type (`str`, defaults to `"pixel_shuffle"`):
|
||||
The type of block to use for upsampling in the decoder.
|
||||
downsample_block_type (`str`, defaults to `"pixel_unshuffle"`):
|
||||
The type of block to use for downsampling in the encoder.
|
||||
decoder_norm_types (`Union[str, Tuple[str]]`, defaults to `"rms_norm"`):
|
||||
decoder_norm_types (`Union[str, tuple[str]]`, defaults to `"rms_norm"`):
|
||||
The normalization type(s) to use in the decoder.
|
||||
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
|
||||
decoder_act_fns (`Union[str, tuple[str]]`, defaults to `"silu"`):
|
||||
The activation function(s) to use in the decoder.
|
||||
encoder_out_shortcut (`bool`, defaults to `True`):
|
||||
Whether to use shortcut at the end of the encoder.
|
||||
@@ -436,18 +436,18 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
in_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
attention_head_dim: int = 32,
|
||||
encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
encoder_block_types: str | tuple[str] = "ResBlock",
|
||||
decoder_block_types: str | tuple[str] = "ResBlock",
|
||||
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||
decoder_norm_types: str | tuple[str] = "rms_norm",
|
||||
decoder_act_fns: str | tuple[str] = "silu",
|
||||
encoder_out_shortcut: bool = True,
|
||||
decoder_in_shortcut: bool = True,
|
||||
decoder_conv_act_fn: str = "relu",
|
||||
@@ -547,7 +547,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
return encoded
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]:
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> EncoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -581,7 +581,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
return decoded
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -665,7 +665,7 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
return (encoded,)
|
||||
return EncoderOutput(latent=encoded)
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, height, width = z.shape
|
||||
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -45,12 +45,12 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
tuple of downsample block types.
|
||||
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
tuple of upsample block types.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
down_block_types: tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: tuple[int] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
@@ -88,8 +88,8 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
sample_size: int = 32,
|
||||
scaling_factor: float = 0.18215,
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
latents_std: Optional[Tuple[float]] = None,
|
||||
latents_mean: Optional[tuple[float]] = None,
|
||||
latents_std: Optional[tuple[float]] = None,
|
||||
force_upcast: bool = True,
|
||||
use_quant_conv: bool = True,
|
||||
use_post_quant_conv: bool = True,
|
||||
@@ -140,7 +140,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -149,7 +149,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -164,7 +164,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -229,7 +229,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -255,7 +255,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
@@ -272,7 +272,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
) -> DecoderOutput | torch.FloatTensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -420,7 +420,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -475,7 +475,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -417,14 +417,14 @@ class AllegroEncoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False],
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_downsample_blocks: tuple[bool, ...] = [True, True, False, False],
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -544,14 +544,14 @@ class AllegroDecoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
),
|
||||
temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False],
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_upsample_blocks: tuple[bool, ...] = [False, True, True, False],
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -687,14 +687,14 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Number of channels in the input image.
|
||||
out_channels (int, defaults to `3`):
|
||||
Number of channels in the output.
|
||||
down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
|
||||
Tuple of strings denoting which types of down blocks to use.
|
||||
up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
|
||||
Tuple of strings denoting which types of up blocks to use.
|
||||
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
Tuple of integers denoting number of output channels in each block.
|
||||
temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
|
||||
Tuple of booleans denoting which blocks to enable temporal downsampling in.
|
||||
down_block_types (`tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
|
||||
tuple of strings denoting which types of down blocks to use.
|
||||
up_block_types (`tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
|
||||
tuple of strings denoting which types of up blocks to use.
|
||||
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
tuple of integers denoting number of output channels in each block.
|
||||
temporal_downsample_blocks (`tuple[bool, ...]`, defaults to `(True, True, False, False)`):
|
||||
tuple of booleans denoting which blocks to enable temporal downsampling in.
|
||||
latent_channels (`int`, defaults to `4`):
|
||||
Number of channels in latents.
|
||||
layers_per_block (`int`, defaults to `2`):
|
||||
@@ -727,21 +727,21 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
"AllegroDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
"AllegroUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
|
||||
temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
temporal_downsample_blocks: tuple[bool, ...] = (True, True, False, False),
|
||||
temporal_upsample_blocks: tuple[bool, ...] = (False, True, True, False),
|
||||
latent_channels: int = 4,
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
@@ -807,7 +807,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
r"""
|
||||
Encode a batch of videos into latents.
|
||||
|
||||
@@ -842,7 +842,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
raise NotImplementedError("Decoding without tiling has not been implemented yet.")
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of videos.
|
||||
|
||||
@@ -1045,7 +1045,7 @@ class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -72,7 +72,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input tensor.
|
||||
out_channels (`int`): Number of output channels produced by the convolution.
|
||||
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
||||
kernel_size (`int` or `tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
||||
stride (`int`, defaults to `1`): Stride of the convolution.
|
||||
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
|
||||
pad_mode (`str`, defaults to `"constant"`): Padding mode.
|
||||
@@ -82,7 +82,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
kernel_size: int | tuple[int, int, int],
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
pad_mode: str = "constant",
|
||||
@@ -174,7 +174,7 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(
|
||||
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
@@ -289,7 +289,7 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
inputs: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
@@ -411,7 +411,7 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXDownBlock3D` class."""
|
||||
|
||||
@@ -506,7 +506,7 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXMidBlock3D` class."""
|
||||
|
||||
@@ -613,7 +613,7 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXUpBlock3D` class."""
|
||||
|
||||
@@ -652,10 +652,10 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
|
||||
options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
@@ -671,13 +671,13 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 256, 512),
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
norm_eps: float = 1e-6,
|
||||
@@ -744,7 +744,7 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXEncoder3D` class."""
|
||||
|
||||
@@ -805,9 +805,9 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
@@ -823,13 +823,13 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 256, 512),
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
norm_eps: float = 1e-6,
|
||||
@@ -903,7 +903,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXDecoder3D` class."""
|
||||
|
||||
@@ -966,12 +966,12 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
tuple of downsample block types.
|
||||
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
tuple of upsample block types.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
||||
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: tuple[str] = (
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str] = (
|
||||
up_block_types: tuple[str] = (
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
||||
block_out_channels: tuple[int] = (128, 256, 256, 512),
|
||||
latent_channels: int = 16,
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
@@ -1018,8 +1018,8 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
sample_width: int = 720,
|
||||
scaling_factor: float = 1.15258426,
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
latents_std: Optional[Tuple[float]] = None,
|
||||
latents_mean: Optional[tuple[float]] = None,
|
||||
latents_std: Optional[tuple[float]] = None,
|
||||
force_upcast: float = True,
|
||||
use_quant_conv: bool = False,
|
||||
use_post_quant_conv: bool = False,
|
||||
@@ -1153,7 +1153,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -1178,7 +1178,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
||||
@@ -1207,7 +1207,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1321,7 +1321,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
enc = torch.cat(result_rows, dim=3)
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1410,7 +1410,7 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor | torch.Tensor:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -47,9 +49,9 @@ class CosmosCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
in_channels: int = 1,
|
||||
out_channels: int = 1,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = (3, 3, 3),
|
||||
dilation: Union[int, Tuple[int, int, int]] = (1, 1, 1),
|
||||
stride: Union[int, Tuple[int, int, int]] = (1, 1, 1),
|
||||
kernel_size: int | tuple[int, int, int] = (3, 3, 3),
|
||||
dilation: int | tuple[int, int, int] = (1, 1, 1),
|
||||
stride: int | tuple[int, int, int] = (1, 1, 1),
|
||||
padding: int = 1,
|
||||
pad_mode: str = "constant",
|
||||
) -> None:
|
||||
@@ -419,7 +421,7 @@ class CosmosCausalAttention(nn.Module):
|
||||
attention_head_dim: int,
|
||||
num_groups: int = 1,
|
||||
dropout: float = 0.0,
|
||||
processor: Union["CosmosSpatialAttentionProcessor2_0", "CosmosTemporalAttentionProcessor2_0"] = None,
|
||||
processor: "CosmosSpatialAttentionProcessor2_0" | "CosmosTemporalAttentionProcessor2_0" = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -711,9 +713,9 @@ class CosmosEncoder3d(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
num_resnet_blocks: int = 2,
|
||||
attention_resolutions: Tuple[int, ...] = (32,),
|
||||
attention_resolutions: tuple[int, ...] = (32,),
|
||||
resolution: int = 1024,
|
||||
patch_size: int = 4,
|
||||
patch_type: str = "haar",
|
||||
@@ -795,9 +797,9 @@ class CosmosDecoder3d(nn.Module):
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
num_resnet_blocks: int = 2,
|
||||
attention_resolutions: Tuple[int, ...] = (32,),
|
||||
attention_resolutions: tuple[int, ...] = (32,),
|
||||
resolution: int = 1024,
|
||||
patch_size: int = 4,
|
||||
patch_type: str = "haar",
|
||||
@@ -886,12 +888,12 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Number of output channels.
|
||||
latent_channels (`int`, defaults to `16`):
|
||||
Number of latent channels.
|
||||
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
encoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
Number of output channels for each encoder down block.
|
||||
decode_block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
|
||||
decode_block_out_channels (`tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
|
||||
Number of output channels for each decoder up block.
|
||||
attention_resolutions (`Tuple[int, ...]`, defaults to `(32,)`):
|
||||
List of image/video resolutions at which to apply attention.
|
||||
attention_resolutions (`tuple[int, ...]`, defaults to `(32,)`):
|
||||
list of image/video resolutions at which to apply attention.
|
||||
resolution (`int`, defaults to `1024`):
|
||||
Base image/video resolution used for computing whether a block should have attention layers.
|
||||
num_layers (`int`, defaults to `2`):
|
||||
@@ -924,9 +926,9 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 16,
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
decode_block_out_channels: Tuple[int, ...] = (256, 512, 512, 512),
|
||||
attention_resolutions: Tuple[int, ...] = (32,),
|
||||
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
decode_block_out_channels: tuple[int, ...] = (256, 512, 512, 512),
|
||||
attention_resolutions: tuple[int, ...] = (32,),
|
||||
resolution: int = 1024,
|
||||
num_layers: int = 2,
|
||||
patch_size: int = 4,
|
||||
@@ -934,8 +936,8 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
scaling_factor: float = 1.0,
|
||||
spatial_compression_ratio: int = 8,
|
||||
temporal_compression_ratio: int = 8,
|
||||
latents_mean: Optional[List[float]] = LATENTS_MEAN,
|
||||
latents_std: Optional[List[float]] = LATENTS_STD,
|
||||
latents_mean: Optional[list[float]] = LATENTS_MEAN,
|
||||
latents_std: Optional[list[float]] = LATENTS_STD,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1050,7 +1052,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
|
||||
@@ -1059,7 +1061,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
@@ -1076,7 +1078,7 @@ class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[Tuple[torch.Tensor], DecoderOutput]:
|
||||
) -> tuple[torch.Tensor] | DecoderOutput:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -50,10 +50,10 @@ class HunyuanVideoCausalConv3d(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
kernel_size: int | tuple[int, int, int] = 3,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
padding: int | tuple[int, int, int] = 0,
|
||||
dilation: int | tuple[int, int, int] = 1,
|
||||
bias: bool = True,
|
||||
pad_mode: str = "replicate",
|
||||
) -> None:
|
||||
@@ -86,7 +86,7 @@ class HunyuanVideoUpsampleCausal3D(nn.Module):
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
bias: bool = True,
|
||||
upsample_factor: Tuple[float, float, float] = (2, 2, 2),
|
||||
upsample_factor: tuple[float, float, float] = (2, 2, 2),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -357,7 +357,7 @@ class HunyuanVideoUpBlock3D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
add_upsample: bool = True,
|
||||
upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
|
||||
upsample_scale_factor: tuple[int, int, int] = (2, 2, 2),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -418,13 +418,13 @@ class HunyuanVideoEncoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -526,13 +526,13 @@ class HunyuanVideoDecoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -641,19 +641,19 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 16,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
"HunyuanVideoDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: int = 32,
|
||||
@@ -779,7 +779,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -804,7 +804,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
@@ -825,7 +825,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -924,7 +924,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1013,7 +1013,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
|
||||
return enc
|
||||
|
||||
def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
|
||||
|
||||
@@ -1055,7 +1055,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -34,9 +34,9 @@ class LTXVideoCausalConv3d(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
kernel_size: int | tuple[int, int, int] = 3,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
dilation: int | tuple[int, int, int] = 1,
|
||||
groups: int = 1,
|
||||
padding_mode: str = "zeros",
|
||||
is_causal: bool = True,
|
||||
@@ -201,7 +201,7 @@ class LTXVideoDownsampler3d(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
is_causal: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
) -> None:
|
||||
@@ -249,7 +249,7 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
@@ -735,11 +735,11 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
Number of input channels.
|
||||
out_channels (`int`, defaults to 128):
|
||||
Number of latent channels.
|
||||
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
The number of output channels for each block.
|
||||
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
Whether a block should contain spatio-temporal downscaling layers or not.
|
||||
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
The number of layers per block.
|
||||
patch_size (`int`, defaults to `4`):
|
||||
The size of spatial patches.
|
||||
@@ -755,16 +755,16 @@ class LTXVideoEncoder3d(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
downsample_type: tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -888,11 +888,11 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
Number of latent channels.
|
||||
out_channels (`int`, defaults to 3):
|
||||
Number of output channels.
|
||||
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
The number of output channels for each block.
|
||||
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
Whether a block should contain spatio-temporal upscaling layers or not.
|
||||
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
The number of layers per block.
|
||||
patch_size (`int`, defaults to `4`):
|
||||
The size of spatial patches.
|
||||
@@ -910,17 +910,17 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
|
||||
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: Tuple[bool, ...] = (False, False, False, False),
|
||||
inject_noise: tuple[bool, ...] = (False, False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
|
||||
upsample_residual: tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: tuple[bool, ...] = (1, 1, 1, 1),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1049,11 +1049,11 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
Number of output channels.
|
||||
latent_channels (`int`, defaults to `128`):
|
||||
Number of latent channels.
|
||||
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
|
||||
The number of output channels for each block.
|
||||
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`:
|
||||
Whether a block should contain spatio-temporal downscaling or not.
|
||||
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
|
||||
The number of layers per block.
|
||||
patch_size (`int`, defaults to `4`):
|
||||
The size of spatial patches.
|
||||
@@ -1082,22 +1082,22 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoDownBlock3D",
|
||||
),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
decoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False, False),
|
||||
downsample_type: tuple[str, ...] = ("conv", "conv", "conv", "conv"),
|
||||
upsample_residual: tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
@@ -1235,7 +1235,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -1261,7 +1261,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
|
||||
def _decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
@@ -1283,7 +1283,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1390,7 +1390,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
|
||||
def tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1480,7 +1480,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
|
||||
def _temporal_tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
|
||||
|
||||
@@ -1523,7 +1523,7 @@ class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrigi
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor | torch.Tensor:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -37,10 +37,10 @@ class EasyAnimateCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, ...]] = 3,
|
||||
stride: Union[int, Tuple[int, ...]] = 1,
|
||||
padding: Union[int, Tuple[int, ...]] = 1,
|
||||
dilation: Union[int, Tuple[int, ...]] = 1,
|
||||
kernel_size: int | tuple[int, ...] = 3,
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] = 1,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
@@ -437,13 +437,13 @@ class EasyAnimateEncoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 8,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"SpatialDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
|
||||
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -553,13 +553,13 @@ class EasyAnimateDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 8,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
up_block_types: tuple[str, ...] = (
|
||||
"SpatialUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
|
||||
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -680,14 +680,14 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
latent_channels: int = 16,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
|
||||
down_block_types: Tuple[str, ...] = [
|
||||
block_out_channels: tuple[int, ...] = [128, 256, 512, 512],
|
||||
down_block_types: tuple[str, ...] = [
|
||||
"SpatialDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
],
|
||||
up_block_types: Tuple[str, ...] = [
|
||||
up_block_types: tuple[str, ...] = [
|
||||
"SpatialUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
@@ -808,7 +808,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def _encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -838,7 +838,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -863,7 +863,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
@@ -890,7 +890,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -983,7 +983,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return moments
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
sample_height = height * self.spatial_compression_ratio
|
||||
sample_width = width * self.spatial_compression_ratio
|
||||
@@ -1050,7 +1050,7 @@ class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -106,7 +106,7 @@ class MochiResnetBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
@@ -193,7 +193,7 @@ class MochiDownBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
chunk_size: int = 2**15,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiUpBlock3D` class."""
|
||||
@@ -294,7 +294,7 @@ class MochiMidBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiMidBlock3D` class."""
|
||||
|
||||
@@ -368,7 +368,7 @@ class MochiUpBlock3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
conv_cache: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiUpBlock3D` class."""
|
||||
|
||||
@@ -445,13 +445,13 @@ class MochiEncoder3D(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
The number of output channels.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
|
||||
layers_per_block (`tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
|
||||
The number of resnet blocks for each block.
|
||||
temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
|
||||
temporal_expansions (`tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
|
||||
The temporal expansion factor for each of the up blocks.
|
||||
spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
|
||||
spatial_expansions (`tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
|
||||
The spatial expansion factor for each of the up blocks.
|
||||
non_linearity (`str`, *optional*, defaults to `"swish"`):
|
||||
The non-linearity to use in the decoder.
|
||||
@@ -461,11 +461,11 @@ class MochiEncoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
|
||||
add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 768),
|
||||
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
temporal_expansions: tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: tuple[int, ...] = (2, 2, 2),
|
||||
add_attention_block: tuple[bool, ...] = (False, True, True, True, True),
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -500,7 +500,7 @@ class MochiEncoder3D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiEncoder3D` class."""
|
||||
|
||||
@@ -558,13 +558,13 @@ class MochiDecoder3D(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
The number of output channels.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
|
||||
layers_per_block (`tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
|
||||
The number of resnet blocks for each block.
|
||||
temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
|
||||
temporal_expansions (`tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
|
||||
The temporal expansion factor for each of the up blocks.
|
||||
spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
|
||||
spatial_expansions (`tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
|
||||
The spatial expansion factor for each of the up blocks.
|
||||
non_linearity (`str`, *optional*, defaults to `"swish"`):
|
||||
The non-linearity to use in the decoder.
|
||||
@@ -574,10 +574,10 @@ class MochiDecoder3D(nn.Module):
|
||||
self,
|
||||
in_channels: int, # 12
|
||||
out_channels: int, # 3
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
|
||||
block_out_channels: tuple[int, ...] = (128, 256, 512, 768),
|
||||
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
temporal_expansions: tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: tuple[int, ...] = (2, 2, 2),
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -613,7 +613,7 @@ class MochiDecoder3D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
self, hidden_states: torch.Tensor, conv_cache: Optional[dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `MochiDecoder3D` class."""
|
||||
|
||||
@@ -668,8 +668,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
@@ -688,15 +688,15 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 15,
|
||||
out_channels: int = 3,
|
||||
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
|
||||
encoder_block_out_channels: tuple[int] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: tuple[int] = (128, 256, 512, 768),
|
||||
latent_channels: int = 12,
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
act_fn: str = "silu",
|
||||
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
|
||||
add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
|
||||
latents_mean: Tuple[float, ...] = (
|
||||
temporal_expansions: tuple[int, ...] = (1, 2, 3),
|
||||
spatial_expansions: tuple[int, ...] = (2, 2, 2),
|
||||
add_attention_block: tuple[bool, ...] = (False, True, True, True, True),
|
||||
latents_mean: tuple[float, ...] = (
|
||||
-0.06730895953510081,
|
||||
-0.038011381506090416,
|
||||
-0.07477820912866141,
|
||||
@@ -710,7 +710,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
-0.011931556316503654,
|
||||
-0.0321993391887285,
|
||||
),
|
||||
latents_std: Tuple[float, ...] = (
|
||||
latents_std: tuple[float, ...] = (
|
||||
0.9263795028493863,
|
||||
0.9248894543193766,
|
||||
0.9393059390890617,
|
||||
@@ -860,7 +860,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -885,7 +885,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
||||
@@ -915,7 +915,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1013,7 +1013,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1097,7 +1097,7 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor | torch.Tensor:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# - GitHub: https://github.com/Wan-Video/Wan2.1
|
||||
# - arXiv: https://arxiv.org/abs/2503.20314
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -58,9 +58,9 @@ class QwenImageCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
kernel_size: int | tuple[int, int, int],
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
padding: int | tuple[int, int, int] = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
@@ -679,13 +679,13 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
dim_mult: tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
attn_scales: list[float] = [],
|
||||
temperal_downsample: list[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
|
||||
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
|
||||
latents_mean: list[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
|
||||
latents_std: list[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
|
||||
) -> None:
|
||||
# fmt: on
|
||||
super().__init__()
|
||||
@@ -806,7 +806,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -856,7 +856,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
return DecoderOutput(sample=out)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -962,7 +962,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1031,7 +1031,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
block_out_channels: tuple[int] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -145,10 +145,10 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
tuple of downsample block types.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of block output channels.
|
||||
layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block.
|
||||
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
down_block_types: tuple[str] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: tuple[int] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
latent_channels: int = 4,
|
||||
sample_size: int = 32,
|
||||
@@ -204,7 +204,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -213,7 +213,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -228,7 +228,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -278,7 +278,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -308,7 +308,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
z: torch.Tensor,
|
||||
num_frames: int,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -339,7 +339,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
num_frames: int = 1,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -149,9 +149,9 @@ class WanCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
kernel_size: int | tuple[int, int, int],
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
padding: int | tuple[int, int, int] = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
@@ -971,12 +971,12 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
base_dim: int = 96,
|
||||
decoder_base_dim: Optional[int] = None,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
dim_mult: tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
attn_scales: list[float] = [],
|
||||
temperal_downsample: list[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
latents_mean: List[float] = [
|
||||
latents_mean: list[float] = [
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
@@ -994,7 +994,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
0.2503,
|
||||
-0.2921,
|
||||
],
|
||||
latents_std: List[float] = [
|
||||
latents_std: list[float] = [
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
@@ -1153,7 +1153,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -1209,7 +1209,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
return DecoderOutput(sample=out)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1315,7 +1315,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1399,7 +1399,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -303,9 +303,9 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
encoder_hidden_size (`int`, *optional*, defaults to 128):
|
||||
Intermediate representation dimension for the encoder.
|
||||
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
|
||||
downsampling_ratios (`list[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
|
||||
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
|
||||
channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
|
||||
channel_multiples (`list[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
|
||||
Multiples used to determine the hidden sizes of the hidden layers.
|
||||
decoder_channels (`int`, *optional*, defaults to 128):
|
||||
Intermediate representation dimension for the decoder.
|
||||
@@ -360,7 +360,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
|
||||
) -> AutoencoderOobleckOutput | tuple[OobleckDiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -386,7 +386,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
return AutoencoderOobleckOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> OobleckDecoderOutput | torch.Tensor:
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
@@ -397,7 +397,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
|
||||
) -> OobleckDecoderOutput | torch.FloatTensor:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -429,7 +429,7 @@ class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[OobleckDecoderOutput, torch.Tensor]:
|
||||
) -> OobleckDecoderOutput | torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -50,11 +50,11 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
||||
Tuple of integers representing the number of output channels for each encoder block. The length of the
|
||||
encoder_block_out_channels (`tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
||||
tuple of integers representing the number of output channels for each encoder block. The length of the
|
||||
tuple should be equal to the number of encoder blocks.
|
||||
decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
||||
Tuple of integers representing the number of output channels for each decoder block. The length of the
|
||||
decoder_block_out_channels (`tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
|
||||
tuple of integers representing the number of output channels for each decoder block. The length of the
|
||||
tuple should be equal to the number of decoder blocks.
|
||||
act_fn (`str`, *optional*, defaults to `"relu"`):
|
||||
Activation function to be used throughout the model.
|
||||
@@ -64,12 +64,12 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
upsampling_scaling_factor (`int`, *optional*, defaults to 2):
|
||||
Scaling factor for upsampling in the decoder. It determines the size of the output image during the
|
||||
upsampling process.
|
||||
num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
|
||||
Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
|
||||
num_encoder_blocks (`tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
|
||||
tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
|
||||
length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
|
||||
number of encoder blocks.
|
||||
num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
|
||||
Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
|
||||
num_decoder_blocks (`tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
|
||||
tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
|
||||
length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
|
||||
number of decoder blocks.
|
||||
latent_magnitude (`float`, *optional*, defaults to 3.0):
|
||||
@@ -99,14 +99,14 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
|
||||
encoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
|
||||
decoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
|
||||
act_fn: str = "relu",
|
||||
upsample_fn: str = "nearest",
|
||||
latent_channels: int = 4,
|
||||
upsampling_scaling_factor: int = 2,
|
||||
num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
|
||||
num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
|
||||
num_encoder_blocks: tuple[int, ...] = (1, 3, 3, 3),
|
||||
num_decoder_blocks: tuple[int, ...] = (3, 3, 3, 1),
|
||||
latent_magnitude: int = 3,
|
||||
latent_shift: float = 0.5,
|
||||
force_upcast: bool = False,
|
||||
@@ -258,7 +258,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return out
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderTinyOutput | tuple[torch.Tensor]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [
|
||||
self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
|
||||
@@ -275,7 +275,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [
|
||||
self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
|
||||
@@ -293,7 +293,7 @@ class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -77,9 +77,9 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
latent_channels: int = 4,
|
||||
sample_size: int = 32,
|
||||
encoder_act_fn: str = "silu",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
|
||||
encoder_double_z: bool = True,
|
||||
encoder_down_block_types: Tuple[str, ...] = (
|
||||
encoder_down_block_types: tuple[str, ...] = (
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
@@ -90,8 +90,8 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
encoder_norm_num_groups: int = 32,
|
||||
encoder_out_channels: int = 4,
|
||||
decoder_add_attention: bool = False,
|
||||
decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
|
||||
decoder_down_block_types: Tuple[str, ...] = (
|
||||
decoder_block_out_channels: tuple[int, ...] = (320, 640, 1024, 1024),
|
||||
decoder_down_block_types: tuple[str, ...] = (
|
||||
"ResnetDownsampleBlock2D",
|
||||
"ResnetDownsampleBlock2D",
|
||||
"ResnetDownsampleBlock2D",
|
||||
@@ -106,7 +106,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
decoder_out_channels: int = 6,
|
||||
decoder_resnet_time_scale_shift: str = "scale_shift",
|
||||
decoder_time_embedding_type: str = "learned",
|
||||
decoder_up_block_types: Tuple[str, ...] = (
|
||||
decoder_up_block_types: tuple[str, ...] = (
|
||||
"ResnetUpsampleBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
@@ -169,7 +169,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -178,7 +178,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -193,7 +193,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
) -> ConsistencyDecoderVAEOutput | tuple[DiagonalGaussianDistribution]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
@@ -285,7 +285,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
num_inference_steps: int = 2,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
"""
|
||||
Decodes the input latent vector `z` using the consistency decoder VAE model.
|
||||
|
||||
@@ -296,7 +296,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
num_inference_steps (int): The number of inference steps. Default is 2.
|
||||
|
||||
Returns:
|
||||
Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
|
||||
Union[DecoderOutput, tuple[torch.Tensor]]: The decoded output.
|
||||
|
||||
"""
|
||||
z = (z * self.config.scaling_factor - self.means) / self.stds
|
||||
@@ -339,7 +339,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput | tuple:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
@@ -400,7 +400,7 @@ class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -66,10 +66,10 @@ class Encoder(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
|
||||
options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
@@ -85,8 +85,8 @@ class Encoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -187,9 +187,9 @@ class Decoder(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
@@ -205,8 +205,8 @@ class Decoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -402,9 +402,9 @@ class MaskConditionDecoder(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`, *optional*, defaults to 3):
|
||||
The number of output channels.
|
||||
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
@@ -420,8 +420,8 @@ class MaskConditionDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
@@ -633,7 +633,7 @@ class VectorQuantizer(nn.Module):
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
|
||||
def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tuple]:
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.vq_embed_dim)
|
||||
@@ -667,7 +667,7 @@ class VectorQuantizer(nn.Module):
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||
|
||||
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
|
||||
def get_codebook_entry(self, indices: torch.LongTensor, shape: tuple[int, ...]) -> torch.Tensor:
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||
@@ -728,7 +728,7 @@ class DiagonalGaussianDistribution(object):
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
|
||||
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
||||
def nll(self, sample: torch.Tensor, dims: tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
@@ -761,10 +761,10 @@ class EncoderTiny(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`):
|
||||
The number of output channels.
|
||||
num_blocks (`Tuple[int, ...]`):
|
||||
num_blocks (`tuple[int, ...]`):
|
||||
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
|
||||
use.
|
||||
block_out_channels (`Tuple[int, ...]`):
|
||||
block_out_channels (`tuple[int, ...]`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
@@ -774,8 +774,8 @@ class EncoderTiny(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: Tuple[int, ...],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_blocks: tuple[int, ...],
|
||||
block_out_channels: tuple[int, ...],
|
||||
act_fn: str,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -827,10 +827,10 @@ class DecoderTiny(nn.Module):
|
||||
The number of input channels.
|
||||
out_channels (`int`):
|
||||
The number of output channels.
|
||||
num_blocks (`Tuple[int, ...]`):
|
||||
num_blocks (`tuple[int, ...]`):
|
||||
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
|
||||
use.
|
||||
block_out_channels (`Tuple[int, ...]`):
|
||||
block_out_channels (`tuple[int, ...]`):
|
||||
The number of output channels for each block.
|
||||
upsampling_scaling_factor (`int`):
|
||||
The scaling factor to use for upsampling.
|
||||
@@ -842,8 +842,8 @@ class DecoderTiny(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: Tuple[int, ...],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_blocks: tuple[int, ...],
|
||||
block_out_channels: tuple[int, ...],
|
||||
upsampling_scaling_factor: int,
|
||||
act_fn: str,
|
||||
upsample_fn: str,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -48,12 +48,12 @@ class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
tuple of downsample block types.
|
||||
up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
tuple of upsample block types.
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
||||
@@ -80,9 +80,9 @@ class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 3,
|
||||
@@ -143,7 +143,7 @@ class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
) -> DecoderOutput | torch.Tensor:
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, commit_loss, _ = self.quantize(h)
|
||||
@@ -161,9 +161,7 @@ class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
|
||||
return DecoderOutput(sample=dec, commit_loss=commit_loss)
|
||||
|
||||
def forward(
|
||||
self, sample: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
|
||||
def forward(self, sample: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor, ...]:
|
||||
r"""
|
||||
The [`VQModel`] forward method.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
from ..utils import deprecate
|
||||
from .controlnets.controlnet import ( # noqa
|
||||
@@ -36,15 +36,15 @@ class ControlNetModel(ControlNetModel):
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -52,11 +52,11 @@ class ControlNetModel(ControlNetModel):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
@@ -66,7 +66,7 @@ class ControlNetModel(ControlNetModel):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
|
||||
@@ -41,7 +39,7 @@ class FluxControlNetModel(FluxControlNetModel):
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
axes_dims_rope: list[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sparsectrl import ( # noqa
|
||||
@@ -50,14 +50,14 @@ class SparseControlNetModel(SparseControlNetModel):
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -65,15 +65,15 @@ class SparseControlNetModel(SparseControlNetModel):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
transformer_layers_per_mid_block: Optional[int | tuple[int]] = None,
|
||||
temporal_transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -57,7 +57,7 @@ class ControlNetOutput(BaseOutput):
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
down_block_res_samples: tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
block_out_channels: tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -119,7 +119,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
only_cross_attention (`Union[bool, tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
@@ -137,7 +137,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
@@ -147,7 +147,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
attention_head_dim (`Union[int, tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
@@ -184,15 +184,15 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -200,11 +200,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
@@ -214,7 +214,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
@@ -444,7 +444,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
):
|
||||
@@ -517,7 +517,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -526,7 +526,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -541,7 +541,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -592,7 +592,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
def set_attention_slice(self, slice_size: str | int | list[int]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -646,7 +646,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
@@ -660,18 +660,18 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
) -> ControlNetOutput | tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
"""
|
||||
The [`ControlNetModel`] forward method.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
@@ -49,7 +49,7 @@ class FlaxControlNetOutput(BaseOutput):
|
||||
|
||||
class FlaxControlNetConditioningEmbedding(nn.Module):
|
||||
conditioning_embedding_channels: int
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
|
||||
block_out_channels: tuple[int, ...] = (16, 32, 96, 256)
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self) -> None:
|
||||
@@ -132,15 +132,15 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
The size of the input sample.
|
||||
in_channels (`int`, *optional*, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
|
||||
down_block_types (`tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
|
||||
attention_head_dim (`int` or `tuple[int]`, *optional*, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
||||
num_attention_heads (`int` or `tuple[int]`, *optional*):
|
||||
The number of attention heads.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 768):
|
||||
The dimension of the cross attention features.
|
||||
@@ -157,17 +157,17 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
|
||||
sample_size: int = 32
|
||||
in_channels: int = 4
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
)
|
||||
only_cross_attention: Union[bool, Tuple[bool, ...]] = False
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
|
||||
only_cross_attention: bool | tuple[bool, ...] = False
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280)
|
||||
layers_per_block: int = 2
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
|
||||
attention_head_dim: int | tuple[int, ...] = 8
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None
|
||||
cross_attention_dim: int = 1280
|
||||
dropout: float = 0.0
|
||||
use_linear_projection: bool = False
|
||||
@@ -175,7 +175,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos: bool = True
|
||||
freq_shift: int = 0
|
||||
controlnet_conditioning_channel_order: str = "rgb"
|
||||
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
|
||||
conditioning_embedding_out_channels: tuple[int, ...] = (16, 32, 96, 256)
|
||||
|
||||
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
||||
# init input tensors
|
||||
@@ -327,13 +327,13 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
def __call__(
|
||||
self,
|
||||
sample: jnp.ndarray,
|
||||
timesteps: Union[jnp.ndarray, float, int],
|
||||
timesteps: jnp.ndarray | float | int,
|
||||
encoder_hidden_states: jnp.ndarray,
|
||||
controlnet_cond: jnp.ndarray,
|
||||
conditioning_scale: float = 1.0,
|
||||
return_dict: bool = True,
|
||||
train: bool = False,
|
||||
) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
|
||||
) -> FlaxControlNetOutput | tuple[tuple[jnp.ndarray, ...], jnp.ndarray]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -34,8 +34,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class FluxControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
@@ -53,7 +53,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
axes_dims_rope: list[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
@@ -129,7 +129,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -222,9 +222,9 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> torch.FloatTensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
@@ -404,7 +404,7 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
compatible with `FluxControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[FluxControlNetModel]`):
|
||||
controlnets (`list[FluxControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`FluxControlNetModel` as a list.
|
||||
"""
|
||||
@@ -416,18 +416,18 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
controlnet_mode: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
controlnet_mode: list[torch.tensor],
|
||||
conditioning_scale: list[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FluxControlNetOutput, Tuple]:
|
||||
) -> FluxControlNetOutput | tuple:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -27,7 +27,7 @@ from ..embeddings import (
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
|
||||
from .controlnet import Tuple, zero_module
|
||||
from .controlnet import zero_module
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class HunyuanControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
@@ -116,7 +116,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -125,7 +125,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
@@ -139,7 +139,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -317,7 +317,7 @@ class HunyuanDiT2DMultiControlNetModel(ModelMixin):
|
||||
designed to be compatible with `HunyuanDiT2DControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[HunyuanDiT2DControlNetModel]`):
|
||||
controlnets (`list[HunyuanDiT2DControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`HunyuanDiT2DControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -39,7 +39,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class QwenImageControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
@@ -55,7 +55,7 @@ class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
|
||||
extra_condition_channels: int = 0, # for controlnet-inpainting
|
||||
):
|
||||
super().__init__()
|
||||
@@ -103,7 +103,7 @@ class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -188,11 +188,11 @@ class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[List[int]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
img_shapes: Optional[list[tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[list[int]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
) -> torch.FloatTensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
@@ -303,7 +303,7 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
to be compatible with `QwenImageControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[QwenImageControlNetModel]`):
|
||||
controlnets (`list[QwenImageControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`QwenImageControlNetModel` as a list.
|
||||
"""
|
||||
@@ -315,16 +315,16 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
conditioning_scale: list[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[List[int]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
img_shapes: Optional[list[tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[list[int]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[QwenImageControlNetOutput, Tuple]:
|
||||
) -> QwenImageControlNetOutput | tuple:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class SanaControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
@@ -119,7 +119,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -128,7 +128,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -143,7 +143,7 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -186,9 +186,9 @@ class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@dataclass
|
||||
class SD3ControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_block_samples: tuple[torch.Tensor]
|
||||
|
||||
|
||||
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
@@ -69,7 +69,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
The maximum latent height/width of positional embeddings.
|
||||
extra_conditioning_channels (`int`, defaults to `0`):
|
||||
The number of extra channels to use for conditioning for patch embedding.
|
||||
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
|
||||
dual_attention_layers (`tuple[int, ...]`, defaults to `()`):
|
||||
The number of dual-stream transformer blocks to use.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
|
||||
@@ -99,7 +99,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
dual_attention_layers: Tuple[int, ...] = (),
|
||||
dual_attention_layers: tuple[int, ...] = (),
|
||||
qk_norm: Optional[str] = None,
|
||||
pos_embed_type: Optional[str] = "sincos",
|
||||
use_pos_embed: bool = True,
|
||||
@@ -206,7 +206,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -215,7 +215,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -230,7 +230,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -337,9 +337,9 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> torch.Tensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`SD3Transformer2DModel`] forward method.
|
||||
|
||||
@@ -460,7 +460,7 @@ class SD3MultiControlNetModel(ModelMixin):
|
||||
compatible with `SD3ControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[SD3ControlNetModel]`):
|
||||
controlnets (`list[SD3ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`SD3ControlNetModel` as a list.
|
||||
"""
|
||||
@@ -472,14 +472,14 @@ class SD3MultiControlNetModel(ModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
conditioning_scale: list[float],
|
||||
pooled_projections: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
joint_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SD3ControlNetOutput, Tuple]:
|
||||
) -> SD3ControlNetOutput | tuple:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -55,7 +55,7 @@ class SparseControlNetOutput(BaseOutput):
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
down_block_res_samples: tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class SparseControlNetConditioningEmbedding(nn.Module):
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
block_out_channels: tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -110,7 +110,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
only_cross_attention (`Union[bool, tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
@@ -128,28 +128,28 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_mid_block (`int` or `tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer layers to use in each layer in the middle block.
|
||||
attention_head_dim (`int` or `Tuple[int]`, defaults to 8):
|
||||
attention_head_dim (`int` or `tuple[int]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
||||
num_attention_heads (`int` or `tuple[int]`, *optional*):
|
||||
The number of heads to use for multi-head attention.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
|
||||
motion_max_seq_length (`int`, defaults to `32`):
|
||||
The maximum sequence length to use in the motion module.
|
||||
motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`):
|
||||
motion_num_attention_heads (`int` or `tuple[int]`, defaults to `8`):
|
||||
The number of heads to use in each attention layer of the motion module.
|
||||
concat_conditioning_mask (`bool`, defaults to `True`):
|
||||
use_simplified_condition_embedding (`bool`, defaults to `True`):
|
||||
@@ -164,14 +164,14 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -179,15 +179,15 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
transformer_layers_per_mid_block: Optional[int | tuple[int]] = None,
|
||||
temporal_transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
@@ -389,7 +389,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
) -> "SparseControlNetModel":
|
||||
@@ -450,7 +450,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -459,7 +459,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -474,7 +474,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -525,7 +525,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
def set_attention_slice(self, slice_size: str | int | list[int]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -579,7 +579,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
@@ -593,17 +593,17 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
conditioning_mask: Optional[torch.Tensor] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
) -> SparseControlNetOutput | tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
"""
|
||||
The [`SparseControlNetModel`] forward method.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -94,7 +94,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
only_cross_attention (`Union[bool, tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
@@ -112,7 +112,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
@@ -122,7 +122,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
attention_head_dim (`Union[int, tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
@@ -156,14 +156,14 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
down_block_types: tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
only_cross_attention: bool | tuple[bool] = False,
|
||||
block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -171,11 +171,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_block: int | tuple[int, ...] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
attention_head_dim: int | tuple[int, ...] = 8,
|
||||
num_attention_heads: Optional[int | tuple[int, ...]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
@@ -185,7 +185,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (48, 96, 192, 384),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (48, 96, 192, 384),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
num_control_type: int = 6,
|
||||
@@ -390,7 +390,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
):
|
||||
r"""
|
||||
@@ -457,7 +457,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -466,7 +466,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -481,7 +481,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -532,7 +532,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
def set_attention_slice(self, slice_size: str | int | list[int]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -586,7 +586,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
@@ -600,21 +600,21 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.Tensor],
|
||||
controlnet_cond: list[torch.Tensor],
|
||||
control_type: torch.Tensor,
|
||||
control_type_idx: List[int],
|
||||
conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
control_type_idx: list[int],
|
||||
conditioning_scale: float | list[float] = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
from_multi: bool = False,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
) -> ControlNetOutput | tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
"""
|
||||
The [`ControlNetUnionModel`] forward method.
|
||||
|
||||
@@ -625,12 +625,12 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`List[torch.Tensor]`):
|
||||
controlnet_cond (`list[torch.Tensor]`):
|
||||
The conditional input tensors.
|
||||
control_type (`torch.Tensor`):
|
||||
A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
|
||||
type is used.
|
||||
control_type_idx (`List[int]`):
|
||||
control_type_idx (`list[int]`):
|
||||
The indices of `control_type`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from math import gcd
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@@ -109,7 +109,7 @@ def get_down_block_adapter(
|
||||
temb_channels: int,
|
||||
max_norm_num_groups: Optional[int] = 32,
|
||||
has_crossattn=True,
|
||||
transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
|
||||
transformer_layers_per_block: Optional[int | tuple[int]] = 1,
|
||||
num_attention_heads: Optional[int] = 1,
|
||||
cross_attention_dim: Optional[int] = 1024,
|
||||
add_downsample: bool = True,
|
||||
@@ -230,7 +230,7 @@ def get_mid_block_adapter(
|
||||
def get_up_block_adapter(
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
ctrl_skip_channels: List[int],
|
||||
ctrl_skip_channels: list[int],
|
||||
):
|
||||
ctrl_to_base = []
|
||||
num_layers = 3 # only support sd + sdxl
|
||||
@@ -278,7 +278,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
The tuple of downsample blocks to use.
|
||||
sample_size (`int`, defaults to 96):
|
||||
Height and width of input/output sample.
|
||||
transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
|
||||
transformer_layers_per_block (`Union[int, tuple[int]]`, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
upcast_attention (`bool`, defaults to `True`):
|
||||
@@ -293,21 +293,21 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
|
||||
time_embedding_mix: float = 1.0,
|
||||
learn_time_embedding: bool = False,
|
||||
num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
num_attention_heads: int | tuple[int] = 4,
|
||||
block_out_channels: tuple[int] = (4, 8, 16, 16),
|
||||
base_block_out_channels: tuple[int] = (320, 640, 1280, 1280),
|
||||
cross_attention_dim: int = 1024,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
sample_size: Optional[int] = 96,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
transformer_layers_per_block: int | tuple[int] = 1,
|
||||
upcast_attention: bool = True,
|
||||
max_norm_num_groups: int = 32,
|
||||
use_linear_projection: bool = True,
|
||||
@@ -430,13 +430,13 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
size_ratio: Optional[float] = None,
|
||||
block_out_channels: Optional[List[int]] = None,
|
||||
num_attention_heads: Optional[List[int]] = None,
|
||||
block_out_channels: Optional[list[int]] = None,
|
||||
num_attention_heads: Optional[list[int]] = None,
|
||||
learn_time_embedding: bool = False,
|
||||
time_embedding_mix: int = 1.0,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
|
||||
@@ -447,9 +447,9 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
size_ratio (float, *optional*, defaults to `None`):
|
||||
When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this
|
||||
or `block_out_channels` must be given.
|
||||
block_out_channels (`List[int]`, *optional*, defaults to `None`):
|
||||
block_out_channels (`list[int]`, *optional*, defaults to `None`):
|
||||
Down blocks output channels in control model. Either this or `size_ratio` must be given.
|
||||
num_attention_heads (`List[int]`, *optional*, defaults to `None`):
|
||||
num_attention_heads (`list[int]`, *optional*, defaults to `None`):
|
||||
The dimension of the attention heads. The naming seems a bit confusing and it is, see
|
||||
https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
|
||||
learn_time_embedding (`bool`, defaults to `False`):
|
||||
@@ -461,7 +461,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
Number of channels of conditioning input (e.g. an image)
|
||||
conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
|
||||
"""
|
||||
|
||||
@@ -529,18 +529,18 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
# unet configs
|
||||
sample_size: Optional[int] = 96,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
up_block_types: tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels: tuple[int] = (320, 640, 1280, 1280),
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
num_attention_heads: Union[int, Tuple[int]] = 8,
|
||||
cross_attention_dim: int | tuple[int] = 1024,
|
||||
transformer_layers_per_block: int | tuple[int] = 1,
|
||||
num_attention_heads: int | tuple[int] = 8,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
upcast_attention: bool = True,
|
||||
@@ -550,11 +550,11 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
# additional controlnet configs
|
||||
time_embedding_mix: float = 1.0,
|
||||
ctrl_conditioning_channels: int = 3,
|
||||
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_embedding_out_channels: tuple[int] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_channel_order: str = "rgb",
|
||||
ctrl_learn_time_embedding: bool = False,
|
||||
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
ctrl_block_out_channels: tuple[int] = (4, 8, 16, 16),
|
||||
ctrl_num_attention_heads: int | tuple[int] = 4,
|
||||
ctrl_max_norm_num_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -721,7 +721,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: Optional[ControlNetXSAdapter] = None,
|
||||
size_ratio: Optional[float] = None,
|
||||
ctrl_block_out_channels: Optional[List[float]] = None,
|
||||
ctrl_block_out_channels: Optional[list[float]] = None,
|
||||
time_embedding_mix: Optional[float] = None,
|
||||
ctrl_optional_kwargs: Optional[Dict] = None,
|
||||
):
|
||||
@@ -737,7 +737,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
adapter will be created.
|
||||
size_ratio (float, *optional*, defaults to `None`):
|
||||
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
|
||||
ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
|
||||
ctrl_block_out_channels (`list[int]`, *optional*, defaults to `None`):
|
||||
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
|
||||
where this parameter is called `block_out_channels`.
|
||||
time_embedding_mix (`float`, *optional*, defaults to None):
|
||||
@@ -865,7 +865,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
def attn_processors(self) -> dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
@@ -874,7 +874,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
@@ -889,7 +889,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -1008,18 +1008,18 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: Optional[torch.Tensor] = None,
|
||||
conditioning_scale: Optional[float] = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
apply_control: bool = True,
|
||||
) -> Union[ControlNetXSOutput, Tuple]:
|
||||
) -> ControlNetXSOutput | tuple:
|
||||
"""
|
||||
The [`ControlNetXSModel`] forward method.
|
||||
|
||||
@@ -1221,7 +1221,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
norm_num_groups: int = 32,
|
||||
ctrl_max_norm_num_groups: int = 32,
|
||||
has_crossattn=True,
|
||||
transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
|
||||
transformer_layers_per_block: Optional[int | tuple[int]] = 1,
|
||||
base_num_attention_heads: Optional[int] = 1,
|
||||
ctrl_num_attention_heads: Optional[int] = 1,
|
||||
cross_attention_dim: Optional[int] = 1024,
|
||||
@@ -1420,10 +1420,10 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
hidden_states_ctrl: Optional[Tensor] = None,
|
||||
conditioning_scale: Optional[float] = 1.0,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[Tensor] = None,
|
||||
apply_control: bool = True,
|
||||
) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
|
||||
) -> tuple[Tensor, Tensor, tuple[Tensor, ...], tuple[Tensor, ...]]:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
@@ -1625,11 +1625,11 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
|
||||
encoder_hidden_states: Tensor,
|
||||
hidden_states_ctrl: Optional[Tensor] = None,
|
||||
conditioning_scale: Optional[float] = 1.0,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
encoder_attention_mask: Optional[Tensor] = None,
|
||||
apply_control: bool = True,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
@@ -1661,7 +1661,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
ctrl_skip_channels: List[int],
|
||||
ctrl_skip_channels: list[int],
|
||||
temb_channels: int,
|
||||
norm_num_groups: int = 32,
|
||||
resolution_idx: Optional[int] = None,
|
||||
@@ -1806,12 +1806,12 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Tensor,
|
||||
res_hidden_states_tuple_base: Tuple[Tensor, ...],
|
||||
res_hidden_states_tuple_ctrl: Tuple[Tensor, ...],
|
||||
res_hidden_states_tuple_base: tuple[Tensor, ...],
|
||||
res_hidden_states_tuple_ctrl: tuple[Tensor, ...],
|
||||
temb: Tensor,
|
||||
encoder_hidden_states: Optional[Tensor] = None,
|
||||
conditioning_scale: Optional[float] = 1.0,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
encoder_attention_mask: Optional[Tensor] = None,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -20,30 +20,30 @@ class MultiControlNetModel(ModelMixin):
|
||||
compatible with `ControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[ControlNetModel]`):
|
||||
controlnets (`list[ControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`ControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
|
||||
def __init__(self, controlnets: list[ControlNetModel] | tuple[ControlNetModel]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
conditioning_scale: list[float],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
) -> ControlNetOutput | tuple:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
down_samples, mid_sample = controlnet(
|
||||
sample=sample,
|
||||
@@ -74,7 +74,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
@@ -111,7 +111,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[str | os.PathLike], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
|
||||
|
||||
@@ -134,7 +134,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -21,32 +21,32 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
be compatible with `ControlNetUnionModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[ControlNetUnionModel]`):
|
||||
controlnets (`list[ControlNetUnionModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`ControlNetUnionModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets: Union[List[ControlNetUnionModel], Tuple[ControlNetUnionModel]]):
|
||||
def __init__(self, controlnets: list[ControlNetUnionModel] | tuple[ControlNetUnionModel]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
timestep: torch.Tensor | float | int,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
control_type: List[torch.Tensor],
|
||||
control_type_idx: List[List[int]],
|
||||
conditioning_scale: List[float],
|
||||
controlnet_cond: list[torch.tensor],
|
||||
control_type: list[torch.Tensor],
|
||||
control_type_idx: list[list[int]],
|
||||
conditioning_scale: list[float],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
) -> ControlNetOutput | tuple:
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
|
||||
@@ -86,7 +86,7 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
# Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained with ControlNet->ControlNetUnion
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
save_directory: str | os.PathLike,
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
@@ -124,7 +124,7 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained with ControlNet->ControlNetUnion
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[str | os.PathLike], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained MultiControlNetUnion model from multiple pre-trained controlnet models.
|
||||
|
||||
@@ -147,7 +147,7 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
device_map (`str` or `dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -168,7 +168,7 @@ class FirDownsample2D(nn.Module):
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
fir_kernel: tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -80,7 +80,7 @@ def get_timestep_embedding(
|
||||
|
||||
def get_3d_sincos_pos_embed(
|
||||
embed_dim: int,
|
||||
spatial_size: Union[int, Tuple[int, int]],
|
||||
spatial_size: int | tuple[int, int],
|
||||
temporal_size: int,
|
||||
spatial_interpolation_scale: float = 1.0,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
@@ -93,7 +93,7 @@ def get_3d_sincos_pos_embed(
|
||||
Args:
|
||||
embed_dim (`int`):
|
||||
The embedding dimension of inputs. It must be divisible by 16.
|
||||
spatial_size (`int` or `Tuple[int, int]`):
|
||||
spatial_size (`int` or `tuple[int, int]`):
|
||||
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
|
||||
spatial dimensions (height and width).
|
||||
temporal_size (`int`):
|
||||
@@ -154,7 +154,7 @@ def get_3d_sincos_pos_embed(
|
||||
|
||||
def _get_3d_sincos_pos_embed_np(
|
||||
embed_dim: int,
|
||||
spatial_size: Union[int, Tuple[int, int]],
|
||||
spatial_size: int | tuple[int, int],
|
||||
temporal_size: int,
|
||||
spatial_interpolation_scale: float = 1.0,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
@@ -165,7 +165,7 @@ def _get_3d_sincos_pos_embed_np(
|
||||
Args:
|
||||
embed_dim (`int`):
|
||||
The embedding dimension of inputs. It must be divisible by 16.
|
||||
spatial_size (`int` or `Tuple[int, int]`):
|
||||
spatial_size (`int` or `tuple[int, int]`):
|
||||
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
|
||||
spatial dimensions (height and width).
|
||||
temporal_size (`int`):
|
||||
@@ -609,10 +609,10 @@ class LuminaPatchEmbed(nn.Module):
|
||||
Patchifies and embeds the input tensor(s).
|
||||
|
||||
Args:
|
||||
x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
|
||||
x (list[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
|
||||
tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]], torch.Tensor]: A tuple containing the patchified
|
||||
and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
|
||||
frequency tensor(s).
|
||||
"""
|
||||
@@ -836,18 +836,18 @@ def get_3d_rotary_pos_embed(
|
||||
theta: int = 10000,
|
||||
use_real: bool = True,
|
||||
grid_type: str = "linspace",
|
||||
max_size: Optional[Tuple[int, int]] = None,
|
||||
max_size: Optional[tuple[int, int]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
crops_coords (`Tuple[int]`):
|
||||
crops_coords (`tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
grid_size (`tuple[int]`):
|
||||
The grid size of the spatial positional embedding (height, width).
|
||||
temporal_size (`int`):
|
||||
The size of the temporal dimension.
|
||||
@@ -934,10 +934,10 @@ def get_3d_rotary_pos_embed_allegro(
|
||||
crops_coords,
|
||||
grid_size,
|
||||
temporal_size,
|
||||
interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
|
||||
interpolation_scale: tuple[float, float, float] = (1.0, 1.0, 1.0),
|
||||
theta: int = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# TODO(aryan): docs
|
||||
start, stop = crops_coords
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
@@ -981,9 +981,9 @@ def get_2d_rotary_pos_embed(
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size
|
||||
crops_coords (`Tuple[int]`)
|
||||
crops_coords (`tuple[int]`)
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
grid_size (`tuple[int]`):
|
||||
The grid size of the positional embedding.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
@@ -1029,9 +1029,9 @@ def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=Tru
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size
|
||||
crops_coords (`Tuple[int]`)
|
||||
crops_coords (`tuple[int]`)
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
grid_size (`tuple[int]`):
|
||||
The grid size of the positional embedding.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
@@ -1119,7 +1119,7 @@ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, n
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[np.ndarray, int],
|
||||
pos: np.ndarray | int,
|
||||
theta: float = 10000.0,
|
||||
use_real=False,
|
||||
linear_factor=1.0,
|
||||
@@ -1186,11 +1186,11 @@ def get_1d_rotary_pos_embed(
|
||||
|
||||
def apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
freqs_cis: torch.Tensor | tuple[torch.Tensor],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
sequence_dim: int = 2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
@@ -1200,10 +1200,10 @@ def apply_rotary_emb(
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
freqs_cis (`tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
tuple[torch.Tensor, torch.Tensor]: tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
@@ -2543,7 +2543,7 @@ class IPAdapterTimeImageProjection(nn.Module):
|
||||
self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
|
||||
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
@@ -2552,7 +2552,7 @@ class IPAdapterTimeImageProjection(nn.Module):
|
||||
timestep (`torch.Tensor`):
|
||||
Timestep in denoising process.
|
||||
Returns:
|
||||
`Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
|
||||
`tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
|
||||
"""
|
||||
timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
|
||||
timestep_emb = self.time_embedding(timestep_emb)
|
||||
@@ -2572,7 +2572,7 @@ class IPAdapterTimeImageProjection(nn.Module):
|
||||
|
||||
|
||||
class MultiIPAdapterImageProjection(nn.Module):
|
||||
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
|
||||
def __init__(self, IPAdapterImageProjectionLayers: list[nn.Module] | tuple[nn.Module]):
|
||||
super().__init__()
|
||||
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
|
||||
|
||||
@@ -2581,7 +2581,7 @@ class MultiIPAdapterImageProjection(nn.Module):
|
||||
"""Number of IP-Adapters loaded."""
|
||||
return len(self.image_projection_layers)
|
||||
|
||||
def forward(self, image_embeds: List[torch.Tensor]):
|
||||
def forward(self, image_embeds: list[torch.Tensor]):
|
||||
projected_image_embeds = []
|
||||
|
||||
# currently, we accept `image_embeds` as
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
# ----------------------------------------------------------------#
|
||||
###################################################################
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -199,7 +199,7 @@ class LoRALinearLayer(nn.Module):
|
||||
out_features: int,
|
||||
rank: int = 4,
|
||||
network_alpha: Optional[float] = None,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
device: Optional[torch.device | str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -260,9 +260,9 @@ class LoRAConv2dLayer(nn.Module):
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
rank: int = 4,
|
||||
kernel_size: Union[int, Tuple[int, int]] = (1, 1),
|
||||
stride: Union[int, Tuple[int, int]] = (1, 1),
|
||||
padding: Union[int, Tuple[int, int], str] = 0,
|
||||
kernel_size: int | tuple[int, int] = (1, 1),
|
||||
stride: int | tuple[int, int] = (1, 1),
|
||||
padding: int | tuple[int, int] | str = 0,
|
||||
network_alpha: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -22,7 +22,7 @@ from array import array
|
||||
from collections import OrderedDict, defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, Optional
|
||||
from zipfile import is_zipfile
|
||||
|
||||
import safetensors
|
||||
@@ -135,7 +135,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
|
||||
return old_class
|
||||
|
||||
|
||||
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
|
||||
def _determine_param_device(param_name: str, device_map: Optional[dict[str, int | str | torch.device]]):
|
||||
"""
|
||||
Find the device of param_name from the device_map.
|
||||
"""
|
||||
@@ -153,10 +153,10 @@ def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Unio
|
||||
|
||||
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike],
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
checkpoint_file: str | os.PathLike,
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = None,
|
||||
disable_mmap: bool = False,
|
||||
map_location: Union[str, torch.device] = "cpu",
|
||||
map_location: str | torch.device = "cpu",
|
||||
):
|
||||
"""
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
@@ -213,17 +213,17 @@ def load_state_dict(
|
||||
def load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict: OrderedDict,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
dtype: Optional[str | torch.dtype] = None,
|
||||
model_name_or_path: Optional[str] = None,
|
||||
hf_quantizer: Optional[DiffusersQuantizer] = None,
|
||||
keep_in_fp32_modules: Optional[List] = None,
|
||||
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
keep_in_fp32_modules: Optional[list] = None,
|
||||
device_map: Optional[dict[str, int | str | torch.device]] = None,
|
||||
unexpected_keys: Optional[list[str]] = None,
|
||||
offload_folder: Optional[str | os.PathLike] = None,
|
||||
offload_index: Optional[Dict] = None,
|
||||
state_dict_index: Optional[Dict] = None,
|
||||
state_dict_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
) -> List[str]:
|
||||
state_dict_folder: Optional[str | os.PathLike] = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
||||
params on a `meta` device. It replaces the model params with the data from the `state_dict`
|
||||
@@ -466,7 +466,7 @@ def _find_mismatched_keys(
|
||||
|
||||
def _load_state_dict_into_model(
|
||||
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
state_dict = state_dict.copy()
|
||||
@@ -505,7 +505,7 @@ def _fetch_index_file(
|
||||
revision,
|
||||
user_agent,
|
||||
commit_hash,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = None,
|
||||
):
|
||||
if is_local:
|
||||
index_file = Path(
|
||||
@@ -555,7 +555,7 @@ def _fetch_index_file_legacy(
|
||||
revision,
|
||||
user_agent,
|
||||
commit_hash,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
dduf_entries: Optional[dict[str, DDUFEntry]] = None,
|
||||
):
|
||||
if is_local:
|
||||
index_file = Path(
|
||||
@@ -714,7 +714,7 @@ def _expand_device_map(device_map, param_names):
|
||||
|
||||
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
|
||||
def _caching_allocator_warmup(
|
||||
model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
|
||||
model, expanded_device_map: dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
|
||||
) -> None:
|
||||
"""
|
||||
This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import os
|
||||
from pickle import UnpicklingError
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -68,7 +68,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
"""
|
||||
return cls(config, **kwargs)
|
||||
|
||||
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
||||
def _cast_floating_to(self, params: Dict | FrozenDict, dtype: jnp.dtype, mask: Any = None) -> Any:
|
||||
"""
|
||||
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
|
||||
"""
|
||||
@@ -92,7 +92,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
|
||||
return unflatten_dict(flat_params)
|
||||
|
||||
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
def to_bf16(self, params: Dict | FrozenDict, mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
|
||||
the `params` in place.
|
||||
@@ -131,7 +131,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
```"""
|
||||
return self._cast_floating_to(params, jnp.bfloat16, mask)
|
||||
|
||||
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
def to_fp32(self, params: Dict | FrozenDict, mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
|
||||
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
|
||||
@@ -158,7 +158,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
```"""
|
||||
return self._cast_floating_to(params, jnp.float32, mask)
|
||||
|
||||
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
||||
def to_fp16(self, params: Dict | FrozenDict, mask: Any = None):
|
||||
r"""
|
||||
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
|
||||
`params` in place.
|
||||
@@ -204,7 +204,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
pretrained_model_name_or_path: str | os.PathLike,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
*model_args,
|
||||
**kwargs,
|
||||
@@ -240,7 +240,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
@@ -493,8 +493,8 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
params: Union[Dict, FrozenDict],
|
||||
save_directory: str | os.PathLike,
|
||||
params: Dict | FrozenDict,
|
||||
is_main_process: bool = True,
|
||||
push_to_hub: bool = False,
|
||||
**kwargs,
|
||||
@@ -516,7 +516,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user