Compare commits

..

17 Commits

Author SHA1 Message Date
Sayak Paul 12f65435cb Merge branch 'main' into complete-sentences-scripts 2025-08-07 09:58:56 +05:30
Dhruv Nair 5780776c8a Make prompt_2 optional in Flux Pipelines (#12073)
* update

* update
2025-08-06 15:40:12 -10:00
Aryan f19421e27c Helper functions to return skip-layer compatible layers (#12048)
update

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-08-06 07:55:16 -10:00
Aryan 69cdc25746 Fix group offloading synchronization bug for parameter-only GroupModule's (#12077)
* update

* update

* refactor

* fuck yeah

* make style

* Update src/diffusers/hooks/group_offloading.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/hooks/group_offloading.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-08-06 21:11:00 +05:30
Aryan cfd6ec7465 [refactor] condense group offloading (#11990)
* update

* update

* refactor

* add test

* address review comment

* nit
2025-08-06 20:01:02 +05:30
jiqing-feng 1082c46afa fix input shape for WanGGUFTexttoVideoSingleFileTests (#12081)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2025-08-06 14:12:40 +05:30
Isotr0py ba2ba9019f Add cuda kernel support for GGUF inference (#11869)
* add gguf kernel support

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix

Signed-off-by: Isotr0py <2037008807@qq.com>

* optimize

Signed-off-by: Isotr0py <2037008807@qq.com>

* update

* update

* update

* update

* update

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-08-05 21:36:48 +05:30
C fa4c0e5e2e optimize QwenImagePipeline to reduce unnecessary CUDA synchronization (#12072) 2025-08-05 04:12:47 -10:00
Sayak Paul b793debd9d [tests] deal with the failing AudioLDM2 tests (#12069)
up
2025-08-05 15:54:25 +05:30
Aryan 377057126c [tests] Fix Qwen test_inference slices (#12070)
update
2025-08-05 14:10:22 +05:30
Sayak Paul 5937e11d85 [docs] small corrections to the example in the Qwen docs (#12068)
* up

* up
2025-08-05 09:47:21 +05:30
Sayak Paul 9c1d4e3be1 [wip] feat: support lora in qwen image and training script (#12056)
* feat: support lora in qwen image and training script

* up

* up

* up

* up

* up

* up

* add lora tests

* fix

* add tests

* fix

* reviewer feedback

* up[

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-08-05 07:06:02 +05:30
Steven Liu 7ea065c507 [docs] Install (#12026)
* initial

* init
2025-08-04 10:13:36 -07:00
Sayak Paul 7a7a487396 fix the rest for all GPUs in CI (#12064)
fix the rest
2025-08-04 21:03:33 +05:30
Sayak Paul 4efb4db9d0 enable all gpus when running ci. (#12062) 2025-08-04 20:17:34 +05:30
Sayak Paul 3e615b3f5b Merge branch 'main' into complete-sentences-scripts 2025-08-01 08:14:31 +05:30
sayakpaul 65c2da5f42 complete the licensing statement. 2025-07-28 11:33:35 +05:30
94 changed files with 3136 additions and 253 deletions
+4 -4
View File
@@ -178,7 +178,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -333,7 +333,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: ["peft"]
additional_deps: ["peft", "kernels"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []
@@ -344,7 +344,7 @@ jobs:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "20gb" --ipc host --gpus 0
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -396,7 +396,7 @@ jobs:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "20gb" --ipc host --gpus 0
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
+1 -1
View File
@@ -253,7 +253,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
+3 -3
View File
@@ -167,7 +167,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -210,7 +210,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-xformers-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -252,7 +252,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
+3 -3
View File
@@ -222,7 +222,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -265,7 +265,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-xformers-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -307,7 +307,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
+1 -1
View File
@@ -30,7 +30,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: ${{ github.event.inputs.docker_image }}
options: --gpus 0 --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- name: Validate test files input
+5
View File
@@ -30,6 +30,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
<Tip>
@@ -105,6 +106,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
## QwenImageLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin
## LoraBaseMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin
+4 -2
View File
@@ -14,7 +14,9 @@
# QwenImage
<!-- TODO: update this section when model is out -->
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
<Tip>
@@ -28,6 +30,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
## QwenImagePipeline
## QwenImagePipelineOutput
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
+75 -102
View File
@@ -12,183 +12,156 @@ specific language governing permissions and limitations under the License.
# Installation
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning library you are using:
Diffusers is tested on Python 3.8+, PyTorch 1.4+, and Flax 0.4.1+. Follow the installation instructions for the deep learning library you're using, [PyTorch](https://pytorch.org/get-started/locally/) or [Flax](https://flax.readthedocs.io/en/latest/).
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions
## Install with pip
You should install 🤗 Diffusers in a [virtual environment](https://docs.python.org/3/library/venv.html).
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
<hfoptions id="install">
<hfoption id="uv">
Create a [virtual environment](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) for easier management of separate projects and to avoid compatibility issues between dependencies. Use [uv](https://docs.astral.sh/uv/), a Rust-based Python package and project manager, to create a virtual environment and install Diffusers.
```bash
uv venv my-env
source my-env/bin/activate
```
</hfoption>
<hfoption id="Python">
Install Diffusers with one of the following methods.
<hfoptions id="install">
<hfoption id="pip">
PyTorch only supports Python 3.8 - 3.11 on Windows.
```bash
python -m venv my-env
source my-env/bin/activate
uv pip install diffusers["torch"] transformers
```
</hfoption>
</hfoptions>
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
<frameworkcontent>
<pt>
PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
```bash
uv install diffusers["torch"] transformers
```
You can also install Diffusers with pip.
```bash
pip install diffusers["torch"] transformers
```
</pt>
<jax>
Install Diffusers with uv.
Use the command below for Flax.
```bash
uv pip install diffusers["flax"] transformers
```
You can also install Diffusers with pip.
```bash
pip install diffusers["flax"] transformers
```
</jax>
</frameworkcontent>
## Install with conda
After activating your virtual environment, with `conda` (maintained by the community):
</hfoption>
<hfoption id="conda">
```bash
conda install -c conda-forge diffusers
```
## Install from source
</hfoption>
<hfoption id="source">
Before installing 🤗 Diffusers from source, make sure you have PyTorch and 🤗 Accelerate installed.
A source install installs the `main` version instead of the latest `stable` version. The `main` version is useful for staying updated with the latest changes but it may not always be stable. If you run into a problem, open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and we will try to resolve it as soon as possible.
To install 🤗 Accelerate:
Make sure [Accelerate](https://huggingface.co/docs/accelerate/index) is installed.
```bash
pip install accelerate
uv pip install accelerate
```
Then install 🤗 Diffusers from source:
Install Diffusers from source with the command below.
```bash
pip install git+https://github.com/huggingface/diffusers
uv pip install git+https://github.com/huggingface/diffusers
```
This command installs the bleeding edge `main` version rather than the latest `stable` version.
The `main` version is useful for staying up-to-date with the latest developments.
For instance, if a bug has been fixed since the last official release but a new release hasn't been rolled out yet.
However, this means the `main` version may not always be stable.
We strive to keep the `main` version operational, and most issues are usually resolved within a few hours or a day.
If you run into a problem, please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can fix it even sooner!
</hfoption>
</hfoptions>
## Editable install
You will need an editable install if you'd like to:
An editable install is recommended for development workflows or if you're using the `main` version of the source code. A special link is created between the cloned repository and the Python library paths. This avoids reinstalling a package after every change.
* Use the `main` version of the source code.
* Contribute to 🤗 Diffusers and need to test changes in the code.
Clone the repository and install Diffusers with the following commands.
Clone the repository and install 🤗 Diffusers with the following commands:
<hfoptions id="editable">
<hfoption id="PyTorch">
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
uv pip install -e ".[torch]"
```
<frameworkcontent>
<pt>
</hfoption>
<hfoption id="Flax">
```bash
pip install -e ".[torch]"
git clone https://github.com/huggingface/diffusers.git
cd diffusers
uv pip install -e ".[flax]"
```
</pt>
<jax>
```bash
pip install -e ".[flax]"
```
</jax>
</frameworkcontent>
These commands will link the folder you cloned the repository to and your Python library paths.
Python will now look inside the folder you cloned to in addition to the normal library paths.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.10/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
</hfoption>
</hfoptions>
<Tip warning={true}>
> [!WARNING]
> You must keep the `diffusers` folder if you want to keep using the library with the editable install.
You must keep the `diffusers` folder if you want to keep using the library.
</Tip>
Now you can easily update your clone to the latest version of 🤗 Diffusers with the following command:
Update your cloned repository to the latest version of Diffusers with the command below.
```bash
cd ~/diffusers/
git pull
```
Your Python environment will find the `main` version of 🤗 Diffusers on the next run.
## Cache
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
Model weights and files are downloaded from the Hub to a cache, which is usually your home directory. Change the cache location with the [HF_HOME](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome) or [HF_HUB_CACHE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubcache) environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
<hfoptions id="cache">
<hfoption id="env variable">
```bash
export HF_HOME="/path/to/your/cache"
export HF_HUB_CACHE="/path/to/your/hub/cache"
```
</hfoption>
<hfoption id="from_pretrained">
```py
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
cache_dir="/path/to/your/cache"
)
```
</hfoption>
</hfoptions>
Cached files allow you to use Diffusers offline. Set the [HF_HUB_OFFLINE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhuboffline) environment variable to `1` to prevent Diffusers from connecting to the internet.
```shell
export HF_HUB_OFFLINE=1
```
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
For more details about managing and cleaning the cache, take a look at the [Understand caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
## Telemetry logging
Our library gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
The data gathered includes the version of 🤗 Diffusers and PyTorch/Flax, the requested model or pipeline class,
and the path to a pretrained checkpoint if it is hosted on the Hugging Face Hub.
Diffusers gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
The data gathered includes the Diffusers and PyTorch/Flax version, the requested model or pipeline class,
and the path to a pretrained checkpoint if it is hosted on the Hub.
This usage data helps us debug issues and prioritize new features.
Telemetry is only sent when loading models and pipelines from the Hub,
and it is not collected if you're loading local files.
We understand that not everyone wants to share additional information,and we respect your privacy.
You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
Opt-out and disable telemetry collection with the [HF_HUB_DISABLE_TELEMETRY](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubdisabletelemetry) environment variable.
On Linux/MacOS:
<hfoptions id="telemetry">
<hfoption id="Linux/macOS">
```bash
export HF_HUB_DISABLE_TELEMETRY=1
```
On Windows:
</hfoption>
<hfoption id="Windows">
```bash
set HF_HUB_DISABLE_TELEMETRY=1
```
</hfoption>
</hfoptions>
+10
View File
@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
## Using Optimized CUDA Kernels with GGUF
Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:
```shell
pip install -U kernels
```
Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
## Supported Quantization Types
- BF16
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
+1
View File
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
+136
View File
@@ -0,0 +1,136 @@
# DreamBooth training example for Qwen Image
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
The `train_dreambooth_lora_qwen_image.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [Qwen Image](https://huggingface.co/Qwen/Qwen-Image).
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_sana.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
```bash
export MODEL_NAME="Qwen/Qwen-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sana-lora"
accelerate launch train_dreambooth_lora_sana.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--learning_rate=2e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
For using `push_to_hub`, make you're logged into your Hugging Face account:
```bash
hf auth login
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
## Notes
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
We provide several options for optimizing memory optimization:
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwenimage) of the `QwenImagePipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
## Using quantization
You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
```json
{
"load_in_4bit": true,
"bnb_4bit_quant_type": "nf4"
}
```
@@ -0,0 +1,248 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAQwenImage(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_qwen_image.py"
transformer_layer_type = "transformer_blocks.0.attn.to_k"
def test_dreambooth_lora_qwen(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
+1
View File
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
File diff suppressed because it is too large Load Diff
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -13,6 +13,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import typing
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
+1
View File
@@ -133,6 +133,7 @@ def _register_attention_processors_metadata():
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
),
)
# FluxAttnProcessor
AttentionProcessorRegistry.register(
model_class=FluxAttnProcessor,
+95 -110
View File
@@ -95,7 +95,7 @@ class ModuleGroup:
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
if self.offload_to_disk_path is not None:
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
@@ -115,6 +115,12 @@ class ModuleGroup:
else:
self.cpu_param_dict = self._init_cpu_param_dict()
self._torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -138,112 +144,76 @@ class ModuleGroup:
@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor
pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
for param, tensor in self.cpu_param_dict.items()
}
yield pinned_dict
finally:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
def _transfer_tensor_to_device(self, tensor, source_tensor):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream and current_stream is not None:
tensor.data.record_stream(current_stream)
if self.record_stream:
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
def _process_tensors_from_modules(self, pinned_memory=None):
for group_module in self.modules:
for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)
for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
self._transfer_tensor_to_device(buffer, source)
for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)
for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
def _onload_from_disk(self, current_stream):
if self.stream is not None:
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
with self._pinned_memory_tensors() as pinned_memory:
for key, tensor_obj in self.key_to_tensor.items():
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
self.cpu_param_dict.clear()
else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
def _onload_from_memory(self, current_stream):
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory, current_stream)
else:
self._process_tensors_from_modules(None, current_stream)
@torch.compiler.disable()
def onload_(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
self._transfer_tensor_to_device(buffer, source)
def _onload_from_disk(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
with context:
if self.offload_to_disk_path:
self._onload_from_disk(current_stream)
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
device = str(self.onload_device) if self.stream is None else "cpu"
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
if self.stream is not None:
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
self._onload_from_memory(current_stream)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
def _onload_from_memory(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
with context:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory)
else:
self._process_tensors_from_modules(None)
def _offload_to_disk(self):
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
@@ -264,14 +234,10 @@ class ModuleGroup:
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
def _offload_to_memory(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
torch_accelerator_module.current_stream().synchronize()
self._torch_accelerator_module.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -279,18 +245,25 @@ class ModuleGroup:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):
r"""Onloads the group of parameters to the onload_device."""
if self.offload_to_disk_path is not None:
self._onload_from_disk()
else:
self._onload_from_memory()
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
r"""Offloads the group of parameters to the offload_device."""
if self.offload_to_disk_path:
self._offload_to_disk()
else:
@@ -307,11 +280,9 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
) -> None:
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
self.group = group
self.next_group = next_group
self.next_group: Optional[ModuleGroup] = None
self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -331,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
if self.group.onload_leader == module:
if self.group.onload_self:
self.group.onload_()
if self.next_group is not None and not self.next_group.onload_self:
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
if should_onload_next_group:
self.next_group.onload_()
should_synchronize = (
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
)
if should_synchronize:
# If this group didn't onload itself, it means it was asynchronously onloaded by the
# previous group. We need to synchronize the side stream to ensure parameters
# are completely loaded to proceed with forward pass. Without this, uninitialized
# weights will be used in the computation, leading to incorrect results
# Also, we should only do this synchronization if we don't already do it from the sync call in
# self.next_group.onload_, hence the `not should_onload_next_group` check.
self.group.stream.synchronize()
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return args, kwargs
@@ -459,8 +444,8 @@ class LayerExecutionTrackerHook(ModelHook):
def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
onload_device: Union[str, torch.device],
offload_device: Union[str, torch.device] = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
@@ -546,6 +531,8 @@ def apply_group_offloading(
```
"""
onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
offload_type = GroupOffloadingType(offload_type)
stream = None
@@ -633,7 +620,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, None, config=config)
_apply_group_offloading_hook(group_module, group, config=config)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +649,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +680,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(submodule, group, None, config=config)
_apply_group_offloading_hook(submodule, group, config=config)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +727,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, None, config=config)
_apply_group_offloading_hook(parent_module, group, config=config)
if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +749,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=_GROUP_ID_LAZY_LEAF,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -777,14 +763,13 @@ def _apply_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -793,7 +778,7 @@ def _apply_lazy_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
+43
View File
@@ -0,0 +1,43 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
module_list_with_transformer_blocks = []
for name, submodule in module.named_modules():
name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
is_modulelist = isinstance(submodule, torch.nn.ModuleList)
if name_endswith_identifier and is_modulelist:
module_list_with_transformer_blocks.append((name, submodule))
return module_list_with_transformer_blocks
def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
attention_layers = []
for name, submodule in module.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES):
attention_layers.append((name, submodule))
return attention_layers
def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
feedforward_layers = []
for name, submodule in module.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
feedforward_layers.append((name, submodule))
return feedforward_layers
+2
View File
@@ -79,6 +79,7 @@ if is_torch_available():
"WanLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -118,6 +119,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,
QwenImageLoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
SkyReelsV2LoraLoaderMixin,
+342
View File
@@ -6538,6 +6538,348 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class QwenImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`QwenImageTransformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
+1
View File
@@ -61,6 +61,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
}
@@ -312,15 +312,14 @@ class AudioLDM2Pipeline(DiffusionPipeline):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.0.dev0"):
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
cache_position_kwargs["model_kwargs"] = model_kwargs
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
@@ -310,7 +310,7 @@ class FluxPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -324,7 +324,7 @@ class FluxControlPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -335,7 +335,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -374,7 +374,7 @@ class FluxControlInpaintPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -341,7 +341,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -335,7 +335,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -346,7 +346,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -419,7 +419,7 @@ class FluxFillPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -333,7 +333,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -337,7 +337,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -358,7 +358,7 @@ class FluxKontextPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -391,7 +391,7 @@ class FluxKontextInpaintPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -292,7 +292,7 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -20,6 +20,7 @@ import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import QwenImageLoraLoaderMixin
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -44,12 +45,12 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import QwenImagePipeline
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/QwenImage-20B", torch_dtype=torch.bfloat16)
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
>>> image = pipe(prompt, num_inference_steps=50).images[0]
>>> image.save("qwenimage.png")
```
"""
@@ -128,7 +129,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class QwenImagePipeline(DiffusionPipeline):
class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
r"""
The QwenImage pipeline for text-to-image generation.
@@ -635,6 +636,11 @@ class QwenImagePipeline(DiffusionPipeline):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -653,7 +659,7 @@ class QwenImagePipeline(DiffusionPipeline):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -667,7 +673,7 @@ class QwenImagePipeline(DiffusionPipeline):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
+92 -3
View File
@@ -12,15 +12,15 @@
# # See the License for the specific language governing permissions and
# # limitations under the License.
import inspect
import os
from contextlib import nullcontext
import gguf
import torch
import torch.nn as nn
from ...utils import is_accelerate_available
from ...utils import is_accelerate_available, is_kernels_available
if is_accelerate_available():
@@ -29,6 +29,82 @@ if is_accelerate_available():
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
can_use_cuda_kernels = (
os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 7
)
if can_use_cuda_kernels and is_kernels_available():
from kernels import get_kernel
ops = get_kernel("Isotr0py/ggml")
else:
ops = None
UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
gguf.GGMLQuantizationType.Q4_0,
gguf.GGMLQuantizationType.Q4_1,
gguf.GGMLQuantizationType.Q5_0,
gguf.GGMLQuantizationType.Q5_1,
gguf.GGMLQuantizationType.Q8_0,
gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
gguf.GGMLQuantizationType.Q2_K,
gguf.GGMLQuantizationType.Q3_K,
gguf.GGMLQuantizationType.Q4_K,
gguf.GGMLQuantizationType.Q5_K,
gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
gguf.GGMLQuantizationType.IQ1_M,
gguf.GGMLQuantizationType.IQ1_S,
gguf.GGMLQuantizationType.IQ2_XXS,
gguf.GGMLQuantizationType.IQ2_XS,
gguf.GGMLQuantizationType.IQ2_S,
gguf.GGMLQuantizationType.IQ3_XXS,
gguf.GGMLQuantizationType.IQ3_S,
gguf.GGMLQuantizationType.IQ4_XS,
gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
# contiguous batching and inefficient with diffusers' batching,
# so we disabled it now.
# elif qweight_type in MMVQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# elif qweight_type in MMQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
if qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.to(x.dtype).T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = gguf.GGMLQuantizationType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y.as_tensor()
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
self.device = device
def forward(self, inputs):
def forward(self, inputs: torch.Tensor):
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
return self.forward_cuda(inputs)
return self.forward_native(inputs)
def forward_native(self, inputs: torch.Tensor):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
def forward_cuda(self, inputs: torch.Tensor):
quant_type = self.weight.quant_type
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
if self.bias is not None:
output += self.bias.to(self.compute_dtype)
return output
+1
View File
@@ -81,6 +81,7 @@ from .import_utils import (
is_invisible_watermark_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,
+5
View File
@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -277,6 +278,10 @@ def is_accelerate_available():
return _accelerate_available
def is_kernels_available():
return _kernels_available
def is_k_diffusion_available():
return _k_diffusion_available
+13
View File
@@ -36,6 +36,7 @@ from .import_utils import (
is_compel_available,
is_flax_available,
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -634,6 +635,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
return decorator
def require_kernels_version_greater_or_equal(kernels_version):
def decorator(test_case):
correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version)
return unittest.skipUnless(
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
+1
View File
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
+1
View File
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
+87
View File
@@ -17,7 +17,9 @@ import gc
import unittest
import torch
from parameterized import parameterized
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ class DummyModelWithMultipleBlocks(ModelMixin):
return x
# Test for https://github.com/huggingface/diffusers/pull/12077
class DummyModelWithLayerNorm(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
super().__init__()
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.blocks = torch.nn.ModuleList(
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
)
self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = self.activation(x)
for block in self.blocks:
x = block(x)
x = self.layer_norm(x)
x = self.linear_2(x)
return x
class DummyPipeline(DiffusionPipeline):
model_cpu_offload_seq = "model"
@@ -113,6 +138,16 @@ class DummyPipeline(DiffusionPipeline):
return x
class LayerOutputTrackerHook(ModelHook):
def __init__(self):
super().__init__()
self.outputs = []
def post_forward(self, module, output):
self.outputs.append(output)
return output
@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
@@ -258,6 +293,7 @@ class GroupOffloadTests(unittest.TestCase):
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
model = DummyModelWithMultipleBlocks(
in_features=self.in_features,
hidden_features=self.hidden_features,
@@ -274,3 +310,54 @@ class GroupOffloadTests(unittest.TestCase):
with context:
model(self.input)
@parameterized.expand([("block_level",), ("leaf_level",)])
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
for name, module in model.named_modules():
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerOutputTrackerHook()
registry.register_hook(hook, "layer_output_tracker")
model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
model = DummyModelWithLayerNorm(128, 256, 128, 2)
model.load_state_dict(model_ref.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
apply_layer_output_tracker_hook(model_ref)
apply_layer_output_tracker_hook(model)
x = torch.randn(2, 128).to(torch_device)
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
num_repeats = 4
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
assert ref_name == name
ref_outputs = (
HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
)
outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
cumulated_absmax = 0.0
for i in range(len(outputs)):
diff = ref_outputs[0] - outputs[i]
absdiff = diff.abs()
absmax = absdiff.max().item()
cumulated_absmax += absmax
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)
+129
View File
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from diffusers import (
AutoencoderKLQwenImage,
FlowMatchEulerDiscreteScheduler,
QwenImagePipeline,
QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
transformer_cls = QwenImageTransformer2DModel
z_dim = 4
vae_kwargs = {
"base_dim": z_dim * 6,
"z_dim": z_dim,
"dim_mult": [1, 2, 4],
"num_res_blocks": 1,
"temperal_downsample": [False, True],
"latents_mean": [0.0] * 4,
"latents_std": [1.0] * 4,
}
vae_cls = AutoencoderKLQwenImage
tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen25VLForCondGen"
text_encoder_cls, text_encoder_id = (
Qwen2_5_VLForConditionalGeneration,
"hf-internal-testing/tiny-random-Qwen25VLForCondGen",
)
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
@property
def output_shape(self):
return (1, 8, 8, 3)
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 4,
"guidance_scale": 0.0,
"height": 8,
"width": 8,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+11 -1
View File
@@ -45,6 +45,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import is_transformers_version
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
@@ -220,6 +221,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
}
return inputs
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -312,7 +318,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components = self.get_dummy_components()
audioldm_pipe = AudioLDM2Pipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
@@ -371,6 +376,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(audio_1 - audio_2).max() < 1e-2
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
+1 -1
View File
@@ -160,7 +160,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertEqual(generated_image.shape, (3, 32, 32))
# fmt: off
expected_slice = torch.tensor([0.563, 0.6358, 0.6028, 0.5656, 0.5806, 0.5512, 0.5712, 0.6331, 0.4147, 0.3558, 0.5625, 0.4831, 0.4957, 0.5258, 0.4075, 0.5018])
expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208])
# fmt: on
generated_slice = generated_image.flatten()
+58 -1
View File
@@ -30,8 +30,10 @@ from diffusers.utils.testing_utils import (
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_accelerator,
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_kernels_version_greater_or_equal,
require_peft_backend,
require_torch_version_greater,
torch_device,
@@ -41,11 +43,66 @@ from ..test_torch_compile_utils import QuantCompileTests
if is_gguf_available():
import gguf
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
enable_full_determinism()
@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
@require_kernels_version_greater_or_equal("0.9.0")
class GGUFCudaKernelsTests(unittest.TestCase):
def setUp(self):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)
def test_cuda_kernels_vs_native(self):
if torch_device != "cuda":
self.skipTest("CUDA kernels test requires CUDA device")
from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
if not can_use_cuda_kernels:
self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
test_quant_types = ["Q4_0", "Q4_K"]
test_shape = (1, 64, 512) # batch, seq_len, hidden_dim
compute_dtype = torch.bfloat16
for quant_type in test_quant_types:
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
in_features, out_features = 512, 512
torch.manual_seed(42)
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
weight = GGUFParameter(weight_data, quant_type=qtype)
x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
linear.weight = weight
linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
linear = linear.to(torch_device)
with torch.no_grad():
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
@nightly
@require_big_accelerator
@require_accelerate
@@ -593,7 +650,7 @@ class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.Test
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
"hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(