Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b71269675e | |||
| 36059182f1 | |||
| 9169e81609 | |||
| 8160289373 | |||
| 08782bf3bf | |||
| 0f252be0ed | |||
| e3d4a6b070 | |||
| ad00c565b7 | |||
| f27949dad9 | |||
| 8d1de40891 | |||
| 8cc528c5e7 | |||
| 3c50f0cdad | |||
| 555b6cc34f | |||
| 5b53f67f06 | |||
| 9918d13eba | |||
| e824660436 | |||
| 03be15e890 | |||
| 85cbe589a7 |
@@ -14,6 +14,10 @@
|
||||
|
||||
# QwenImage
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
|
||||
|
||||
Qwen-Image comes in the following variants:
|
||||
@@ -86,6 +90,12 @@ image.save("qwen_fewsteps.png")
|
||||
|
||||
</details>
|
||||
|
||||
<Tip>
|
||||
|
||||
The `guidance_scale` parameter in the pipeline is there to support future guidance-distilled models when they come up. Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance, please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should enable classifier-free guidance computations.
|
||||
|
||||
</Tip>
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
|
||||
@@ -333,6 +333,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
|
||||
- Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involed. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
|
||||
|
||||
- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. One can set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) examples to learn more.
|
||||
|
||||
## WanPipeline
|
||||
|
||||
[[autodoc]] WanPipeline
|
||||
|
||||
@@ -90,7 +90,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
check_min_version("0.35.0")
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -62,7 +62,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -64,7 +64,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -19,8 +19,9 @@ cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
Install the requirements in the `examples/dreambooth` folder as shown below.
|
||||
```bash
|
||||
cd examples/dreambooth
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
@@ -75,9 +75,9 @@ Now, we can launch training using:
|
||||
```bash
|
||||
export MODEL_NAME="Qwen/Qwen-Image"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-sana-lora"
|
||||
export OUTPUT_DIR="trained-qwenimage-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sana.py \
|
||||
accelerate launch train_dreambooth_lora_qwenimage.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
|
||||
@@ -64,7 +64,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -80,7 +80,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -82,7 +82,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.35.0.dev0")
|
||||
check_min_version("0.35.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -269,7 +269,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.35.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.35.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.35.0.dev0"
|
||||
__version__ = "0.35.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -754,7 +754,11 @@ class LoraBaseMixin:
|
||||
# Decompose weights into weights for denoiser and text encoders.
|
||||
_component_adapter_weights = {}
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component)
|
||||
model = getattr(self, component, None)
|
||||
# To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
|
||||
# Whereas in Wan 2.2, we have two denoisers.
|
||||
if model is None:
|
||||
continue
|
||||
|
||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||
if isinstance(weights, dict):
|
||||
|
||||
@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
|
||||
)
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = original_state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
if key.endswith((".diff", ".diff_b")) and "norm" in key:
|
||||
# NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
|
||||
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
for i in range(min_block, max_block + 1):
|
||||
# Self-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
|
||||
has_alpha = alpha_key in original_state_dict
|
||||
original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
|
||||
|
||||
original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
|
||||
|
||||
if has_alpha:
|
||||
down_weight = original_state_dict.pop(original_key_A)
|
||||
up_weight = original_state_dict.pop(original_key_B)
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[converted_key_A] = down_weight * scale_down
|
||||
converted_state_dict[converted_key_B] = up_weight * scale_up
|
||||
|
||||
else:
|
||||
if original_key_A in original_state_dict:
|
||||
converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
|
||||
if original_key_B in original_state_dict:
|
||||
converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
|
||||
|
||||
original_key = f"blocks.{i}.self_attn.{o}.diff_b"
|
||||
converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
|
||||
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
|
||||
# Cross-attention
|
||||
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
|
||||
has_alpha = alpha_key in original_state_dict
|
||||
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
|
||||
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
|
||||
|
||||
if original_key_A in original_state_dict:
|
||||
down_weight = original_state_dict.pop(original_key_A)
|
||||
converted_state_dict[converted_key_A] = down_weight
|
||||
if original_key_B in original_state_dict:
|
||||
up_weight = original_state_dict.pop(original_key_B)
|
||||
converted_state_dict[converted_key_B] = up_weight
|
||||
if has_alpha:
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[converted_key_A] *= scale_down
|
||||
converted_state_dict[converted_key_B] *= scale_up
|
||||
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
|
||||
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
|
||||
if is_i2v_lora:
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
|
||||
has_alpha = alpha_key in original_state_dict
|
||||
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
|
||||
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
|
||||
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
|
||||
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
|
||||
|
||||
if original_key_A in original_state_dict:
|
||||
down_weight = original_state_dict.pop(original_key_A)
|
||||
converted_state_dict[converted_key_A] = down_weight
|
||||
if original_key_B in original_state_dict:
|
||||
up_weight = original_state_dict.pop(original_key_B)
|
||||
converted_state_dict[converted_key_B] = up_weight
|
||||
if has_alpha:
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[converted_key_A] *= scale_down
|
||||
converted_state_dict[converted_key_B] *= scale_up
|
||||
|
||||
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
|
||||
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
|
||||
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
|
||||
# FFN
|
||||
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
|
||||
original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
|
||||
converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
alpha_key = f"blocks.{i}.{o}.alpha"
|
||||
has_alpha = alpha_key in original_state_dict
|
||||
original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
|
||||
converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
|
||||
|
||||
original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
|
||||
converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
|
||||
converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
|
||||
|
||||
if original_key_A in original_state_dict:
|
||||
down_weight = original_state_dict.pop(original_key_A)
|
||||
converted_state_dict[converted_key_A] = down_weight
|
||||
if original_key_B in original_state_dict:
|
||||
up_weight = original_state_dict.pop(original_key_B)
|
||||
converted_state_dict[converted_key_B] = up_weight
|
||||
if has_alpha:
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[converted_key_A] *= scale_down
|
||||
converted_state_dict[converted_key_B] *= scale_up
|
||||
|
||||
original_key = f"blocks.{i}.{o}.diff_b"
|
||||
converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
|
||||
@@ -2080,6 +2129,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
|
||||
|
||||
|
||||
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
|
||||
if has_lora_unet:
|
||||
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
|
||||
|
||||
def convert_key(key: str) -> str:
|
||||
prefix = "transformer_blocks"
|
||||
if "." in key:
|
||||
base, suffix = key.rsplit(".", 1)
|
||||
else:
|
||||
base, suffix = key, ""
|
||||
|
||||
start = f"{prefix}_"
|
||||
rest = base[len(start) :]
|
||||
|
||||
if "." in rest:
|
||||
head, tail = rest.split(".", 1)
|
||||
tail = "." + tail
|
||||
else:
|
||||
head, tail = rest, ""
|
||||
|
||||
# Protected n-grams that must keep their internal underscores
|
||||
protected = {
|
||||
# pairs
|
||||
("to", "q"),
|
||||
("to", "k"),
|
||||
("to", "v"),
|
||||
("to", "out"),
|
||||
("add", "q"),
|
||||
("add", "k"),
|
||||
("add", "v"),
|
||||
("txt", "mlp"),
|
||||
("img", "mlp"),
|
||||
("txt", "mod"),
|
||||
("img", "mod"),
|
||||
# triplets
|
||||
("add", "q", "proj"),
|
||||
("add", "k", "proj"),
|
||||
("add", "v", "proj"),
|
||||
("to", "add", "out"),
|
||||
}
|
||||
|
||||
prot_by_len = {}
|
||||
for ng in protected:
|
||||
prot_by_len.setdefault(len(ng), set()).add(ng)
|
||||
|
||||
parts = head.split("_")
|
||||
merged = []
|
||||
i = 0
|
||||
lengths_desc = sorted(prot_by_len.keys(), reverse=True)
|
||||
|
||||
while i < len(parts):
|
||||
matched = False
|
||||
for L in lengths_desc:
|
||||
if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
|
||||
merged.append("_".join(parts[i : i + L]))
|
||||
i += L
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
merged.append(parts[i])
|
||||
i += 1
|
||||
|
||||
head_converted = ".".join(merged)
|
||||
converted_base = f"{prefix}.{head_converted}{tail}"
|
||||
return converted_base + (("." + suffix) if suffix else "")
|
||||
|
||||
state_dict = {convert_key(k): v for k, v in state_dict.items()}
|
||||
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
|
||||
@@ -5065,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
_lora_loadable_modules = ["transformer", "transformer_2"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@@ -5270,15 +5270,35 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
|
||||
if load_into_transformer_2:
|
||||
if not hasattr(self, "transformer_2"):
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute transformer_2"
|
||||
"Note that Wan2.1 models do not have a transformer_2 component."
|
||||
"Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
|
||||
)
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=self.transformer_2,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
else:
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name)
|
||||
if not hasattr(self, "transformer")
|
||||
else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
||||
@@ -5668,15 +5688,35 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
|
||||
if load_into_transformer_2:
|
||||
if not hasattr(self, "transformer_2"):
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute transformer_2"
|
||||
"Note that Wan2.1 models do not have a transformer_2 component."
|
||||
"Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
|
||||
)
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=self.transformer_2,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
else:
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name)
|
||||
if not hasattr(self, "transformer")
|
||||
else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
|
||||
@@ -6643,7 +6683,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
|
||||
if has_alphas_in_sd:
|
||||
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
|
||||
if has_alphas_in_sd or has_lora_unet:
|
||||
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
|
||||
@@ -110,6 +110,27 @@ if _CAN_USE_XFORMERS_ATTN:
|
||||
else:
|
||||
xops = None
|
||||
|
||||
# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
|
||||
if torch.__version__ >= "2.4.0":
|
||||
_custom_op = torch.library.custom_op
|
||||
_register_fake = torch.library.register_fake
|
||||
else:
|
||||
|
||||
def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
|
||||
def wrap(func):
|
||||
return func
|
||||
|
||||
return wrap if fn is None else fn
|
||||
|
||||
def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
|
||||
def wrap(func):
|
||||
return func
|
||||
|
||||
return wrap if fn is None else fn
|
||||
|
||||
_custom_op = custom_op_no_op
|
||||
_register_fake = register_fake_no_op
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
|
||||
# ===== torch op registrations =====
|
||||
# Registrations are required for fullgraph tracing compatibility
|
||||
|
||||
|
||||
# TODO: library.custom_op and register_fake probably need version guards?
|
||||
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
|
||||
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
|
||||
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
|
||||
|
||||
|
||||
@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
|
||||
def _wrapped_flash_attn_3_original(
|
||||
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
|
||||
return out, lse
|
||||
|
||||
|
||||
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
|
||||
@_register_fake("flash_attn_3::_flash_attn_forward")
|
||||
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, seq_len, num_heads, head_dim = query.shape
|
||||
lse_shape = (batch_size, seq_len, num_heads)
|
||||
|
||||
@@ -299,6 +299,7 @@ class Decoder(nn.Module):
|
||||
act_fn: Union[str, Tuple[str]] = "silu",
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
in_shortcut: bool = True,
|
||||
conv_act_fn: str = "relu",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -349,7 +350,7 @@ class Decoder(nn.Module):
|
||||
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
|
||||
|
||||
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
|
||||
self.conv_act = nn.ReLU()
|
||||
self.conv_act = get_activation(conv_act_fn)
|
||||
self.conv_out = None
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The normalization type(s) to use in the decoder.
|
||||
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
|
||||
The activation function(s) to use in the decoder.
|
||||
encoder_out_shortcut (`bool`, defaults to `True`):
|
||||
Whether to use shortcut at the end of the encoder.
|
||||
decoder_in_shortcut (`bool`, defaults to `True`):
|
||||
Whether to use shortcut at the beginning of the decoder.
|
||||
decoder_conv_act_fn (`str`, defaults to `"relu"`):
|
||||
The activation function to use at the end of the decoder.
|
||||
scaling_factor (`float`, defaults to `1.0`):
|
||||
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
|
||||
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
|
||||
@@ -441,6 +448,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||
encoder_out_shortcut: bool = True,
|
||||
decoder_in_shortcut: bool = True,
|
||||
decoder_conv_act_fn: str = "relu",
|
||||
scaling_factor: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -454,6 +464,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
layers_per_block=encoder_layers_per_block,
|
||||
qkv_multiscales=encoder_qkv_multiscales,
|
||||
downsample_block_type=downsample_block_type,
|
||||
out_shortcut=encoder_out_shortcut,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
in_channels=in_channels,
|
||||
@@ -466,6 +477,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_type=decoder_norm_types,
|
||||
act_fn=decoder_act_fns,
|
||||
upsample_block_type=upsample_block_type,
|
||||
in_shortcut=decoder_in_shortcut,
|
||||
conv_act_fn=decoder_conv_act_fn,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
|
||||
|
||||
@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
|
||||
very large margin.
|
||||
"""
|
||||
factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
|
||||
# Remove disk and cpu devices, and cast to proper torch.device
|
||||
|
||||
# Keep only accelerator devices
|
||||
accelerator_device_map = {
|
||||
param: torch.device(device)
|
||||
for param, device in expanded_device_map.items()
|
||||
if str(device) not in ["cpu", "disk"]
|
||||
}
|
||||
total_byte_count = defaultdict(lambda: 0)
|
||||
if not accelerator_device_map:
|
||||
return
|
||||
|
||||
elements_per_device = defaultdict(int)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
try:
|
||||
param = model.get_parameter(param_name)
|
||||
p = model.get_parameter(param_name)
|
||||
except AttributeError:
|
||||
param = model.get_buffer(param_name)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = param.numel() * param.element_size()
|
||||
try:
|
||||
p = model.get_buffer(param_name)
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
|
||||
# TODO: account for TP when needed.
|
||||
total_byte_count[device] += param_byte_count
|
||||
elements_per_device[device] += p.numel()
|
||||
|
||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||
for device, byte_count in total_byte_count.items():
|
||||
_ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
|
||||
for device, elem_count in elements_per_device.items():
|
||||
warmup_elems = max(1, elem_count // factor)
|
||||
_ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
|
||||
|
||||
@@ -350,7 +350,9 @@ class LTXVideoTransformerBlock(nn.Module):
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
|
||||
batch_size, temb.size(1), num_ada_params, -1
|
||||
)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
|
||||
@@ -665,12 +665,12 @@ class WanTransformer3DModel(
|
||||
# 5. Output norm, projection & unpatchify
|
||||
if temb.ndim == 3:
|
||||
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
|
||||
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
|
||||
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
|
||||
shift = shift.squeeze(2)
|
||||
scale = scale.squeeze(2)
|
||||
else:
|
||||
# batch_size, inner_dim
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
|
||||
@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
|
||||
control_hidden_states = control_hidden_states + hidden_states
|
||||
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
self.scale_shift_table.to(temb.device) + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
@@ -359,7 +359,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
|
||||
hidden_states = hidden_states + control_hint * scale
|
||||
|
||||
# 6. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
|
||||
@@ -290,7 +290,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
trust_remote_code: bool = False,
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
hub_kwargs_names = [
|
||||
|
||||
@@ -48,10 +48,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import dispatch_model
|
||||
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
|
||||
|
||||
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
|
||||
if is_transformers_version("<=", "4.56.2"):
|
||||
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
@@ -830,6 +838,9 @@ def load_sub_model(
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
if is_transformers_model and is_transformers_version(">=", "4.57.0"):
|
||||
loading_kwargs.pop("offload_state_dict")
|
||||
|
||||
if (
|
||||
quantization_config is not None
|
||||
and isinstance(quantization_config, PipelineQuantizationConfig)
|
||||
|
||||
@@ -480,6 +480,11 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
|
||||
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
|
||||
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
|
||||
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
|
||||
enable classifier-free guidance computations.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
|
||||
@@ -62,25 +62,6 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> image.save("qwenimage_edit.png")
|
||||
```
|
||||
"""
|
||||
PREFERRED_QWENIMAGE_RESOLUTIONS = [
|
||||
(672, 1568),
|
||||
(688, 1504),
|
||||
(720, 1456),
|
||||
(752, 1392),
|
||||
(800, 1328),
|
||||
(832, 1248),
|
||||
(880, 1184),
|
||||
(944, 1104),
|
||||
(1024, 1024),
|
||||
(1104, 944),
|
||||
(1184, 880),
|
||||
(1248, 832),
|
||||
(1328, 800),
|
||||
(1392, 752),
|
||||
(1456, 720),
|
||||
(1504, 688),
|
||||
(1568, 672),
|
||||
]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
@@ -565,7 +546,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
_auto_resize: bool = True,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -597,6 +577,11 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
|
||||
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
|
||||
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
|
||||
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
|
||||
enable classifier-free guidance computations.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -641,8 +626,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
image_size = image[0].size if isinstance(image, list) else image.size
|
||||
width, height = image_size
|
||||
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
|
||||
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
|
||||
height = height or calculated_height
|
||||
width = width or calculated_width
|
||||
|
||||
@@ -680,18 +664,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
device = self._execution_device
|
||||
# 3. Preprocess image
|
||||
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
||||
img = image[0] if isinstance(image, list) else image
|
||||
image_height, image_width = self.image_processor.get_default_height_width(img)
|
||||
aspect_ratio = image_width / image_height
|
||||
if _auto_resize:
|
||||
_, image_width, image_height = min(
|
||||
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
|
||||
)
|
||||
image_width = image_width // multiple_of * multiple_of
|
||||
image_height = image_height // multiple_of * multiple_of
|
||||
image = self.image_processor.resize(image, image_height, image_width)
|
||||
image = self.image_processor.resize(image, calculated_height, calculated_width)
|
||||
prompt_image = image
|
||||
image = self.image_processor.preprocess(image, image_height, image_width)
|
||||
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
|
||||
image = image.unsqueeze(2)
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
@@ -708,9 +683,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_true_cfg:
|
||||
# negative image is the same size as the original image, but all pixels are white
|
||||
# negative_image = Image.new("RGB", (image.width, image.height), (255, 255, 255))
|
||||
|
||||
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||
image=prompt_image,
|
||||
prompt=negative_prompt,
|
||||
@@ -737,7 +709,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
img_shapes = [
|
||||
[
|
||||
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
|
||||
(1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
|
||||
(1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
|
||||
]
|
||||
] * batch_size
|
||||
|
||||
|
||||
@@ -568,6 +568,11 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
|
||||
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
|
||||
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
|
||||
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
|
||||
enable classifier-free guidance computations.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
|
||||
@@ -698,6 +698,11 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
|
||||
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
|
||||
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
|
||||
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
|
||||
enable classifier-free guidance computations.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
|
||||
@@ -339,7 +339,8 @@ def offload_models(
|
||||
original_devices = [next(m.parameters()).device for m in modules]
|
||||
else:
|
||||
assert len(modules) == 1
|
||||
original_devices = modules[0].device
|
||||
# For DiffusionPipeline, wrap the device in a list to make it iterable
|
||||
original_devices = [modules[0].device]
|
||||
# move to target device
|
||||
for m in modules:
|
||||
m.to(device)
|
||||
|
||||
@@ -45,7 +45,6 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
|
||||
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
|
||||
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
|
||||
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
|
||||
@@ -20,6 +20,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
@@ -33,7 +34,6 @@ from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
|
||||
from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -159,25 +159,52 @@ def check_imports(filename):
|
||||
return get_relative_imports(filename)
|
||||
|
||||
|
||||
def _raise_timeout_error(signum, frame):
|
||||
raise ValueError(
|
||||
"Loading this model requires you to execute custom code contained in the model repository on your local "
|
||||
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
||||
)
|
||||
|
||||
|
||||
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
|
||||
trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
|
||||
if DIFFUSERS_DISABLE_REMOTE_CODE:
|
||||
logger.warning(
|
||||
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
|
||||
)
|
||||
if trust_remote_code is None:
|
||||
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
||||
prev_sig_handler = None
|
||||
try:
|
||||
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
||||
signal.alarm(TIME_OUT_REMOTE_CODE)
|
||||
while trust_remote_code is None:
|
||||
answer = input(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
||||
f"Do you wish to run the custom code? [y/N] "
|
||||
)
|
||||
if answer.lower() in ["yes", "y", "1"]:
|
||||
trust_remote_code = True
|
||||
elif answer.lower() in ["no", "n", "0", ""]:
|
||||
trust_remote_code = False
|
||||
signal.alarm(0)
|
||||
except Exception:
|
||||
# OS which does not support signal.SIGALRM
|
||||
raise ValueError(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
finally:
|
||||
if prev_sig_handler is not None:
|
||||
signal.signal(signal.SIGALRM, prev_sig_handler)
|
||||
signal.alarm(0)
|
||||
elif has_remote_code:
|
||||
# For the CI which puts the timeout at 0
|
||||
_raise_timeout_error(None, None)
|
||||
|
||||
if has_remote_code and not trust_remote_code:
|
||||
error_msg = f"The repository for {model_name} contains custom code. "
|
||||
error_msg += (
|
||||
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
|
||||
if DIFFUSERS_DISABLE_REMOTE_CODE
|
||||
else "Pass `trust_remote_code=True` to allow loading remote code modules."
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
elif has_remote_code and trust_remote_code:
|
||||
logger.warning(
|
||||
f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
|
||||
raise ValueError(
|
||||
f"Loading {model_name} requires you to execute the configuration file in that"
|
||||
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||
" set the option `trust_remote_code=True` to remove this error."
|
||||
)
|
||||
|
||||
return trust_remote_code
|
||||
|
||||
Reference in New Issue
Block a user