Compare commits

..

18 Commits

Author SHA1 Message Date
sayakpaul b71269675e Release: v0.35.2-patch 2025-10-15 09:23:57 +05:30
Vladimir Mandic 36059182f1 fix scale_shift_factor being on cpu for wan and ltx (#12347)
* wan fix scale_shift_factor being on cpu

* apply device cast to ltx transformer

* Apply style fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-10-15 09:19:40 +05:30
Aishwarya Badlani 9169e81609 Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.… (#12206)
* Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op

- Add hasattr() check for torch.library.custom_op and register_fake
- These functions were added in PyTorch 2.4, causing import failures in 2.3.1
- Both decorators and functions are now properly guarded with version checks
- Maintains backward compatibility while preserving functionality

Fixes #12195

* Use dummy decorators approach for PyTorch version compatibility

- Replace hasattr check with version string comparison
- Add no-op decorator functions for PyTorch < 2.4.0
- Follows pattern from #11941 as suggested by reviewer
- Maintains cleaner code structure without indentation changes

* Update src/diffusers/models/attention_dispatch.py

Update all the decorator usages

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Move version check to top of file and use private naming as requested

* Apply style fixes

---------

Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-10-15 09:14:51 +05:30
Dhruv Nair 8160289373 [CI] Fix TRANSFORMERS_FLAX_WEIGHTS_NAME import issue (#12354)
update
2025-10-15 09:09:25 +05:30
Sayak Paul 08782bf3bf handle offload_state_dict when initing transformers models (#12438) 2025-10-15 09:08:31 +05:30
sayakpaul 0f252be0ed Release: v0.35.1-patch 2025-08-20 09:42:00 +05:30
naykun e3d4a6b070 Performance Improve for Qwen Image Edit (#12190)
* fix(qwen-image-edit):
- update condition reshaping logic to improve editing performance

* fix(qwen-image-edit):
- remove _auto_resize
2025-08-20 09:40:28 +05:30
naykun ad00c565b7 Emergency fix for Qwen-Image-Edit (#12188)
fix(qwen-image):
shape calculation fix
2025-08-20 09:39:52 +05:30
sayakpaul f27949dad9 Release: v0.35.0 2025-08-19 08:34:27 +05:30
Linoy Tsaban 8d1de40891 [Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora (#12074)
* add alpha

* load into 2nd transformer

* Update src/diffusers/loaders/lora_conversion_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders/lora_conversion_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* pr comments

* pr comments

* pr comments

* fix

* fix

* Apply style fixes

* fix copies

* fix

* fix copies

* Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* revert change

* revert change

* fix copies

* up

* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: linoy <linoy@hf.co>
2025-08-19 08:32:39 +05:30
Sayak Paul 8cc528c5e7 [chore] add lora button to qwenimage docs (#12183)
up
2025-08-19 07:13:24 +05:30
Taechai 3c50f0cdad Update README.md (#12182)
* Update README.md

Specify the full dir

* Update examples/dreambooth/README.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-08-18 13:02:49 -07:00
Sayak Paul 555b6cc34f [LoRA] feat: support more Qwen LoRAs from the community. (#12170)
* feat: support more Qwen LoRAs from the community.

* revert unrelated changes.

* Revert "revert unrelated changes."

This reverts commit 82dea555dc.
2025-08-18 20:56:28 +05:30
Sayak Paul 5b53f67f06 [docs] Clarify guidance scale in Qwen pipelines (#12181)
* add clarification regarding guidance_scale in QwenImage

* propagate.
2025-08-18 20:10:23 +05:30
MQY 9918d13eba fix(training_utils): wrap device in list for DiffusionPipeline (#12178)
- Modify offload_models function to handle DiffusionPipeline correctly
- Ensure compatibility with both single and multiple module inputs

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-08-18 13:56:17 +05:30
Sayak Paul e824660436 fix: caching allocator behaviour for quantization. (#12172)
* fix: caching allocator behaviour for quantization.

* up

* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-08-18 13:16:18 +05:30
Leo Jiang 03be15e890 [Docs] typo error in qwen image (#12144)
typo error in qwen image

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2025-08-18 11:55:42 +05:30
Junyu Chen 85cbe589a7 Minor modification to support DC-AE-turbo (#12169)
* minor modification to support dc-ae-turbo

* minor
2025-08-18 11:37:36 +05:30
74 changed files with 430 additions and 190 deletions
+10
View File
@@ -14,6 +14,10 @@
# QwenImage
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
Qwen-Image comes in the following variants:
@@ -86,6 +90,12 @@ image.save("qwen_fewsteps.png")
</details>
<Tip>
The `guidance_scale` parameter in the pipeline is there to support future guidance-distilled models when they come up. Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance, please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should enable classifier-free guidance computations.
</Tip>
## QwenImagePipeline
[[autodoc]] QwenImagePipeline
+2
View File
@@ -333,6 +333,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involed. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. One can set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) examples to learn more.
## WanPipeline
[[autodoc]] WanPipeline
@@ -90,7 +90,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -88,7 +88,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -95,7 +95,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -52,7 +52,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
class MarigoldDepthOutput(BaseOutput):
"""
@@ -74,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -67,7 +67,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -80,7 +80,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
if is_torch_npu_available():
+1 -1
View File
@@ -62,7 +62,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -62,7 +62,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -64,7 +64,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+2 -1
View File
@@ -19,8 +19,9 @@ cd diffusers
pip install -e .
```
Then cd in the example folder and run
Install the requirements in the `examples/dreambooth` folder as shown below.
```bash
cd examples/dreambooth
pip install -r requirements.txt
```
+2 -2
View File
@@ -75,9 +75,9 @@ Now, we can launch training using:
```bash
export MODEL_NAME="Qwen/Qwen-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sana-lora"
export OUTPUT_DIR="trained-qwenimage-lora"
accelerate launch train_dreambooth_lora_sana.py \
accelerate launch train_dreambooth_lora_qwenimage.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
+1 -1
View File
@@ -64,7 +64,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+1 -1
View File
@@ -80,7 +80,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -75,7 +75,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -87,7 +87,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.34.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -75,7 +75,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -75,7 +75,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -87,7 +87,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -80,7 +80,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -64,7 +64,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -55,7 +55,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
@@ -53,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -82,7 +82,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = logging.getLogger(__name__)
@@ -77,7 +77,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.35.0.dev0")
check_min_version("0.35.0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -269,7 +269,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.35.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.35.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+1 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.35.0.dev0"
__version__ = "0.35.2"
from typing import TYPE_CHECKING
+5 -1
View File
@@ -754,7 +754,11 @@ class LoraBaseMixin:
# Decompose weights into weights for denoiser and text encoders.
_component_adapter_weights = {}
for component in self._lora_loadable_modules:
model = getattr(self, component)
model = getattr(self, component, None)
# To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
# Whereas in Wan 2.2, we have two denoisers.
if model is None:
continue
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
+149 -32
View File
@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
)
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = original_state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for key in list(original_state_dict.keys()):
if key.endswith((".diff", ".diff_b")) and "norm" in key:
# NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
for i in range(min_block, max_block + 1):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
has_alpha = alpha_key in original_state_dict
original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
if has_alpha:
down_weight = original_state_dict.pop(original_key_A)
up_weight = original_state_dict.pop(original_key_B)
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[converted_key_A] = down_weight * scale_down
converted_state_dict[converted_key_B] = up_weight * scale_up
else:
if original_key_A in original_state_dict:
converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
if original_key_B in original_state_dict:
converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
original_key = f"blocks.{i}.self_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
has_alpha = alpha_key in original_state_dict
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key_A in original_state_dict:
down_weight = original_state_dict.pop(original_key_A)
converted_state_dict[converted_key_A] = down_weight
if original_key_B in original_state_dict:
up_weight = original_state_dict.pop(original_key_B)
converted_state_dict[converted_key_B] = up_weight
if has_alpha:
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[converted_key_A] *= scale_down
converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
has_alpha = alpha_key in original_state_dict
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key_A in original_state_dict:
down_weight = original_state_dict.pop(original_key_A)
converted_state_dict[converted_key_A] = down_weight
if original_key_B in original_state_dict:
up_weight = original_state_dict.pop(original_key_B)
converted_state_dict[converted_key_B] = up_weight
if has_alpha:
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[converted_key_A] *= scale_down
converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
# FFN
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
alpha_key = f"blocks.{i}.{o}.alpha"
has_alpha = alpha_key in original_state_dict
original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
if original_key_A in original_state_dict:
down_weight = original_state_dict.pop(original_key_A)
converted_state_dict[converted_key_A] = down_weight
if original_key_B in original_state_dict:
up_weight = original_state_dict.pop(original_key_B)
converted_state_dict[converted_key_B] = up_weight
if has_alpha:
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[converted_key_A] *= scale_down
converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.{o}.diff_b"
converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
@@ -2080,6 +2129,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
if has_lora_unet:
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
def convert_key(key: str) -> str:
prefix = "transformer_blocks"
if "." in key:
base, suffix = key.rsplit(".", 1)
else:
base, suffix = key, ""
start = f"{prefix}_"
rest = base[len(start) :]
if "." in rest:
head, tail = rest.split(".", 1)
tail = "." + tail
else:
head, tail = rest, ""
# Protected n-grams that must keep their internal underscores
protected = {
# pairs
("to", "q"),
("to", "k"),
("to", "v"),
("to", "out"),
("add", "q"),
("add", "k"),
("add", "v"),
("txt", "mlp"),
("img", "mlp"),
("txt", "mod"),
("img", "mod"),
# triplets
("add", "q", "proj"),
("add", "k", "proj"),
("add", "v", "proj"),
("to", "add", "out"),
}
prot_by_len = {}
for ng in protected:
prot_by_len.setdefault(len(ng), set()).add(ng)
parts = head.split("_")
merged = []
i = 0
lengths_desc = sorted(prot_by_len.keys(), reverse=True)
while i < len(parts):
matched = False
for L in lengths_desc:
if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
merged.append("_".join(parts[i : i + L]))
i += L
matched = True
break
if not matched:
merged.append(parts[i])
i += 1
head_converted = ".".join(merged)
converted_base = f"{prefix}.{head_converted}{tail}"
return converted_base + (("." + suffix) if suffix else "")
state_dict = {convert_key(k): v for k, v in state_dict.items()}
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
+61 -20
View File
@@ -5065,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
"""
_lora_loadable_modules = ["transformer"]
_lora_loadable_modules = ["transformer", "transformer_2"]
transformer_name = TRANSFORMER_NAME
@classmethod
@@ -5270,15 +5270,35 @@ class WanLoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
if not hasattr(self, "transformer_2"):
raise AttributeError(
f"'{type(self).__name__}' object has no attribute transformer_2"
"Note that Wan2.1 models do not have a transformer_2 component."
"Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
)
self.load_lora_into_transformer(
state_dict,
transformer=self.transformer_2,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
else:
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
@@ -5668,15 +5688,35 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
if not hasattr(self, "transformer_2"):
raise AttributeError(
f"'{type(self).__name__}' object has no attribute transformer_2"
"Note that Wan2.1 models do not have a transformer_2 component."
"Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
)
self.load_lora_into_transformer(
state_dict,
transformer=self.transformer_2,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
else:
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
@@ -6643,7 +6683,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
if has_alphas_in_sd:
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
if has_alphas_in_sd or has_lora_unet:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict
+25 -5
View File
@@ -110,6 +110,27 @@ if _CAN_USE_XFORMERS_ATTN:
else:
xops = None
# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
if torch.__version__ >= "2.4.0":
_custom_op = torch.library.custom_op
_register_fake = torch.library.register_fake
else:
def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
def wrap(func):
return func
return wrap if fn is None else fn
def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
def wrap(func):
return func
return wrap if fn is None else fn
_custom_op = custom_op_no_op
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_original(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
return out, lse
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
@_register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)
@@ -299,6 +299,7 @@ class Decoder(nn.Module):
act_fn: Union[str, Tuple[str]] = "silu",
upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True,
conv_act_fn: str = "relu",
):
super().__init__()
@@ -349,7 +350,7 @@ class Decoder(nn.Module):
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
self.conv_act = nn.ReLU()
self.conv_act = get_activation(conv_act_fn)
self.conv_out = None
if layers_per_block[0] > 0:
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The normalization type(s) to use in the decoder.
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
The activation function(s) to use in the decoder.
encoder_out_shortcut (`bool`, defaults to `True`):
Whether to use shortcut at the end of the encoder.
decoder_in_shortcut (`bool`, defaults to `True`):
Whether to use shortcut at the beginning of the decoder.
decoder_conv_act_fn (`str`, defaults to `"relu"`):
The activation function to use at the end of the decoder.
scaling_factor (`float`, defaults to `1.0`):
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
@@ -441,6 +448,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
downsample_block_type: str = "pixel_unshuffle",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
encoder_out_shortcut: bool = True,
decoder_in_shortcut: bool = True,
decoder_conv_act_fn: str = "relu",
scaling_factor: float = 1.0,
) -> None:
super().__init__()
@@ -454,6 +464,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type,
out_shortcut=encoder_out_shortcut,
)
self.decoder = Decoder(
in_channels=in_channels,
@@ -466,6 +477,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
norm_type=decoder_norm_types,
act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type,
in_shortcut=decoder_in_shortcut,
conv_act_fn=decoder_conv_act_fn,
)
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
+15 -9
View File
@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
very large margin.
"""
factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
# Remove disk and cpu devices, and cast to proper torch.device
# Keep only accelerator devices
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
total_byte_count = defaultdict(lambda: 0)
if not accelerator_device_map:
return
elements_per_device = defaultdict(int)
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
p = model.get_parameter(param_name)
except AttributeError:
param = model.get_buffer(param_name)
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
param_byte_count = param.numel() * param.element_size()
try:
p = model.get_buffer(param_name)
except AttributeError:
raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
# TODO: account for TP when needed.
total_byte_count[device] += param_byte_count
elements_per_device[device] += p.numel()
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, byte_count in total_byte_count.items():
_ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
for device, elem_count in elements_per_device.items():
warmup_elems = max(1, elem_count // factor)
_ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
@@ -350,7 +350,9 @@ class LTXVideoTransformerBlock(nn.Module):
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
batch_size, temb.size(1), num_ada_params, -1
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
@@ -665,12 +665,12 @@ class WanTransformer3DModel(
# 5. Output norm, projection & unpatchify
if temb.ndim == 3:
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
shift = shift.squeeze(2)
scale = scale.squeeze(2)
else:
# batch_size, inner_dim
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
control_hidden_states = control_hidden_states + hidden_states
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
self.scale_shift_table.to(temb.device) + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
@@ -359,7 +359,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
hidden_states = hidden_states + control_hint * scale
# 6. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
@@ -290,7 +290,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: bool = False,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
hub_kwargs_names = [
@@ -48,10 +48,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
if is_transformers_version("<=", "4.56.2"):
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
if is_transformers_version("<=", "4.56.2"):
weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -830,6 +838,9 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
if is_transformers_model and is_transformers_version(">=", "4.57.0"):
loading_kwargs.pop("offload_state_dict")
if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)
@@ -480,6 +480,11 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -62,25 +62,6 @@ EXAMPLE_DOC_STRING = """
>>> image.save("qwenimage_edit.png")
```
"""
PREFERRED_QWENIMAGE_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
@@ -565,7 +546,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
_auto_resize: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -597,6 +577,11 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -641,8 +626,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
returning a tuple, the first element is a list with the generated images.
"""
image_size = image[0].size if isinstance(image, list) else image.size
width, height = image_size
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
height = height or calculated_height
width = width or calculated_width
@@ -680,18 +664,9 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
device = self._execution_device
# 3. Preprocess image
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
img = image[0] if isinstance(image, list) else image
image_height, image_width = self.image_processor.get_default_height_width(img)
aspect_ratio = image_width / image_height
if _auto_resize:
_, image_width, image_height = min(
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
)
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
image = self.image_processor.resize(image, image_height, image_width)
image = self.image_processor.resize(image, calculated_height, calculated_width)
prompt_image = image
image = self.image_processor.preprocess(image, image_height, image_width)
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
image = image.unsqueeze(2)
has_neg_prompt = negative_prompt is not None or (
@@ -708,9 +683,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
max_sequence_length=max_sequence_length,
)
if do_true_cfg:
# negative image is the same size as the original image, but all pixels are white
# negative_image = Image.new("RGB", (image.width, image.height), (255, 255, 255))
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
image=prompt_image,
prompt=negative_prompt,
@@ -737,7 +709,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
img_shapes = [
[
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
(1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
(1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
]
] * batch_size
@@ -568,6 +568,11 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -698,6 +698,11 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+2 -1
View File
@@ -339,7 +339,8 @@ def offload_models(
original_devices = [next(m.parameters()).device for m in modules]
else:
assert len(modules) == 1
original_devices = modules[0].device
# For DiffusionPipeline, wrap the device in a list to make it iterable
original_devices = [modules[0].device]
# move to target device
for m in modules:
m.to(device)
-1
View File
@@ -45,7 +45,6 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
+44 -17
View File
@@ -20,6 +20,7 @@ import json
import os
import re
import shutil
import signal
import sys
import threading
from pathlib import Path
@@ -33,7 +34,6 @@ from packaging import version
from .. import __version__
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -159,25 +159,52 @@ def check_imports(filename):
return get_relative_imports(filename)
def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute custom code contained in the model repository on your local "
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
)
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
if DIFFUSERS_DISABLE_REMOTE_CODE:
logger.warning(
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
)
if trust_remote_code is None:
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
prev_sig_handler = None
try:
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not trust_remote_code:
error_msg = f"The repository for {model_name} contains custom code. "
error_msg += (
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
if DIFFUSERS_DISABLE_REMOTE_CODE
else "Pass `trust_remote_code=True` to allow loading remote code modules."
)
raise ValueError(error_msg)
elif has_remote_code and trust_remote_code:
logger.warning(
f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
return trust_remote_code