Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul 96ef4bf62a Merge branch 'main' into enable-all-gpus 2025-08-04 19:09:32 +05:30
sayakpaul 866740a42b enable all gpus when running ci. 2025-08-04 16:36:06 +05:30
33 changed files with 240 additions and 3977 deletions
+4 -4
View File
@@ -178,7 +178,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus all --shm-size "16gb" --ipc host
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -333,7 +333,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: ["peft", "kernels"]
additional_deps: ["peft"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []
@@ -344,7 +344,7 @@ jobs:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "20gb" --ipc host --gpus all
options: --shm-size "20gb" --ipc host --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -396,7 +396,7 @@ jobs:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "20gb" --ipc host --gpus all
options: --shm-size "20gb" --ipc host --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
+1 -1
View File
@@ -253,7 +253,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus all --shm-size "16gb" --ipc host
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
+3 -3
View File
@@ -167,7 +167,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus all --shm-size "16gb" --ipc host
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -210,7 +210,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-xformers-cuda
options: --gpus all --shm-size "16gb" --ipc host
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -252,7 +252,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus all --shm-size "16gb" --ipc host
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
+3 -3
View File
@@ -222,7 +222,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus all --shm-size "16gb" --ipc host
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -265,7 +265,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-xformers-cuda
options: --gpus all --shm-size "16gb" --ipc host
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -307,7 +307,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus all --shm-size "16gb" --ipc host
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
+1 -1
View File
@@ -30,7 +30,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: ${{ github.event.inputs.docker_image }}
options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
options: --gpus 0 --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- name: Validate test files input
-5
View File
@@ -30,7 +30,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
<Tip>
@@ -106,10 +105,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
## QwenImageLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin
## LoraBaseMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin
+2 -4
View File
@@ -14,9 +14,7 @@
# QwenImage
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
<!-- TODO: update this section when model is out -->
<Tip>
@@ -30,6 +28,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
## QwenImagePipelineOutput
## QwenImagePipeline
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
+102 -75
View File
@@ -12,156 +12,183 @@ specific language governing permissions and limitations under the License.
# Installation
Diffusers is tested on Python 3.8+, PyTorch 1.4+, and Flax 0.4.1+. Follow the installation instructions for the deep learning library you're using, [PyTorch](https://pytorch.org/get-started/locally/) or [Flax](https://flax.readthedocs.io/en/latest/).
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning library you are using:
Create a [virtual environment](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) for easier management of separate projects and to avoid compatibility issues between dependencies. Use [uv](https://docs.astral.sh/uv/), a Rust-based Python package and project manager, to create a virtual environment and install Diffusers.
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions
## Install with pip
You should install 🤗 Diffusers in a [virtual environment](https://docs.python.org/3/library/venv.html).
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
<hfoptions id="install">
<hfoption id="uv">
```bash
uv venv my-env
source my-env/bin/activate
```
Install Diffusers with one of the following methods.
<hfoptions id="install">
<hfoption id="pip">
PyTorch only supports Python 3.8 - 3.11 on Windows.
</hfoption>
<hfoption id="Python">
```bash
uv pip install diffusers["torch"] transformers
python -m venv my-env
source my-env/bin/activate
```
Use the command below for Flax.
</hfoption>
</hfoptions>
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
<frameworkcontent>
<pt>
PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
```bash
uv install diffusers["torch"] transformers
```
You can also install Diffusers with pip.
```bash
pip install diffusers["torch"] transformers
```
</pt>
<jax>
Install Diffusers with uv.
```bash
uv pip install diffusers["flax"] transformers
```
</hfoption>
<hfoption id="conda">
You can also install Diffusers with pip.
```bash
pip install diffusers["flax"] transformers
```
</jax>
</frameworkcontent>
## Install with conda
After activating your virtual environment, with `conda` (maintained by the community):
```bash
conda install -c conda-forge diffusers
```
</hfoption>
<hfoption id="source">
## Install from source
A source install installs the `main` version instead of the latest `stable` version. The `main` version is useful for staying updated with the latest changes but it may not always be stable. If you run into a problem, open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and we will try to resolve it as soon as possible.
Before installing 🤗 Diffusers from source, make sure you have PyTorch and 🤗 Accelerate installed.
Make sure [Accelerate](https://huggingface.co/docs/accelerate/index) is installed.
To install 🤗 Accelerate:
```bash
uv pip install accelerate
pip install accelerate
```
Install Diffusers from source with the command below.
Then install 🤗 Diffusers from source:
```bash
uv pip install git+https://github.com/huggingface/diffusers
pip install git+https://github.com/huggingface/diffusers
```
</hfoption>
</hfoptions>
This command installs the bleeding edge `main` version rather than the latest `stable` version.
The `main` version is useful for staying up-to-date with the latest developments.
For instance, if a bug has been fixed since the last official release but a new release hasn't been rolled out yet.
However, this means the `main` version may not always be stable.
We strive to keep the `main` version operational, and most issues are usually resolved within a few hours or a day.
If you run into a problem, please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can fix it even sooner!
## Editable install
An editable install is recommended for development workflows or if you're using the `main` version of the source code. A special link is created between the cloned repository and the Python library paths. This avoids reinstalling a package after every change.
You will need an editable install if you'd like to:
Clone the repository and install Diffusers with the following commands.
* Use the `main` version of the source code.
* Contribute to 🤗 Diffusers and need to test changes in the code.
<hfoptions id="editable">
<hfoption id="PyTorch">
Clone the repository and install 🤗 Diffusers with the following commands:
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
uv pip install -e ".[torch]"
```
</hfoption>
<hfoption id="Flax">
<frameworkcontent>
<pt>
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
uv pip install -e ".[flax]"
pip install -e ".[torch]"
```
</pt>
<jax>
```bash
pip install -e ".[flax]"
```
</jax>
</frameworkcontent>
</hfoption>
</hfoptions>
These commands will link the folder you cloned the repository to and your Python library paths.
Python will now look inside the folder you cloned to in addition to the normal library paths.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.10/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
> [!WARNING]
> You must keep the `diffusers` folder if you want to keep using the library with the editable install.
<Tip warning={true}>
Update your cloned repository to the latest version of Diffusers with the command below.
You must keep the `diffusers` folder if you want to keep using the library.
</Tip>
Now you can easily update your clone to the latest version of 🤗 Diffusers with the following command:
```bash
cd ~/diffusers/
git pull
```
Your Python environment will find the `main` version of 🤗 Diffusers on the next run.
## Cache
Model weights and files are downloaded from the Hub to a cache, which is usually your home directory. Change the cache location with the [HF_HOME](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome) or [HF_HUB_CACHE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubcache) environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
<hfoptions id="cache">
<hfoption id="env variable">
```bash
export HF_HOME="/path/to/your/cache"
export HF_HUB_CACHE="/path/to/your/hub/cache"
```
</hfoption>
<hfoption id="from_pretrained">
```py
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
cache_dir="/path/to/your/cache"
)
```
</hfoption>
</hfoptions>
Cached files allow you to use Diffusers offline. Set the [HF_HUB_OFFLINE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhuboffline) environment variable to `1` to prevent Diffusers from connecting to the internet.
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
```shell
export HF_HUB_OFFLINE=1
```
For more details about managing and cleaning the cache, take a look at the [Understand caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
## Telemetry logging
Diffusers gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
The data gathered includes the Diffusers and PyTorch/Flax version, the requested model or pipeline class,
and the path to a pretrained checkpoint if it is hosted on the Hub.
Our library gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
The data gathered includes the version of 🤗 Diffusers and PyTorch/Flax, the requested model or pipeline class,
and the path to a pretrained checkpoint if it is hosted on the Hugging Face Hub.
This usage data helps us debug issues and prioritize new features.
Telemetry is only sent when loading models and pipelines from the Hub,
and it is not collected if you're loading local files.
Opt-out and disable telemetry collection with the [HF_HUB_DISABLE_TELEMETRY](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubdisabletelemetry) environment variable.
We understand that not everyone wants to share additional information,and we respect your privacy.
You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
<hfoptions id="telemetry">
<hfoption id="Linux/macOS">
On Linux/MacOS:
```bash
export HF_HUB_DISABLE_TELEMETRY=1
```
</hfoption>
<hfoption id="Windows">
On Windows:
```bash
set HF_HUB_DISABLE_TELEMETRY=1
```
</hfoption>
</hfoptions>
-10
View File
@@ -53,16 +53,6 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
## Using Optimized CUDA Kernels with GGUF
Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:
```shell
pip install -U kernels
```
Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
## Supported Quantization Types
- BF16
-136
View File
@@ -1,136 +0,0 @@
# DreamBooth training example for Qwen Image
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
The `train_dreambooth_lora_qwen_image.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [Qwen Image](https://huggingface.co/Qwen/Qwen-Image).
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_sana.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
```bash
export MODEL_NAME="Qwen/Qwen-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sana-lora"
accelerate launch train_dreambooth_lora_sana.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--learning_rate=2e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
For using `push_to_hub`, make you're logged into your Hugging Face account:
```bash
hf auth login
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
## Notes
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
We provide several options for optimizing memory optimization:
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwenimage) of the `QwenImagePipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
## Using quantization
You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
```json
{
"load_in_4bit": true,
"bnb_4bit_quant_type": "nf4"
}
```
@@ -1,248 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAQwenImage(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_qwen_image.py"
transformer_layer_type = "transformer_blocks.0.attn.to_k"
def test_dreambooth_lora_qwen(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
File diff suppressed because it is too large Load Diff
+108 -80
View File
@@ -95,7 +95,7 @@ class ModuleGroup:
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False
if self.offload_to_disk_path is not None:
if self.offload_to_disk_path:
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
@@ -115,12 +115,6 @@ class ModuleGroup:
else:
self.cpu_param_dict = self._init_cpu_param_dict()
self._torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -144,76 +138,112 @@ class ModuleGroup:
@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
for param, tensor in self.cpu_param_dict.items()
}
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor
yield pinned_dict
finally:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor):
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
if self.record_stream and current_stream is not None:
tensor.data.record_stream(current_stream)
def _process_tensors_from_modules(self, pinned_memory=None):
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
for group_module in self.modules:
for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source)
self._transfer_tensor_to_device(param, source, current_stream)
for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source)
self._transfer_tensor_to_device(buffer, source, current_stream)
for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source)
self._transfer_tensor_to_device(param, source, current_stream)
for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source)
self._transfer_tensor_to_device(buffer, source, current_stream)
def _onload_from_disk(self, current_stream):
if self.stream is not None:
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
with self._pinned_memory_tensors() as pinned_memory:
for key, tensor_obj in self.key_to_tensor.items():
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
self.cpu_param_dict.clear()
else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
def _onload_from_memory(self, current_stream):
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory, current_stream)
else:
self._process_tensors_from_modules(None, current_stream)
@torch.compiler.disable()
def onload_(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
def _onload_from_disk(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
with context:
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
device = str(self.onload_device) if self.stream is None else "cpu"
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
if self.stream is not None:
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
if self.offload_to_disk_path:
self._onload_from_disk(current_stream)
else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
def _onload_from_memory(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
with context:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory)
else:
self._process_tensors_from_modules(None)
self._onload_from_memory(current_stream)
def _offload_to_disk(self):
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
@@ -234,10 +264,14 @@ class ModuleGroup:
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
def _offload_to_memory(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
self._torch_accelerator_module.current_stream().synchronize()
torch_accelerator_module.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -248,23 +282,15 @@ class ModuleGroup:
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
group_module.to(self.offload_device, non_blocking=self.non_blocking)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):
r"""Onloads the group of parameters to the onload_device."""
if self.offload_to_disk_path is not None:
self._onload_from_disk()
else:
self._onload_from_memory()
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of parameters to the offload_device."""
r"""Offloads the group of modules to the offload_device."""
if self.offload_to_disk_path:
self._offload_to_disk()
else:
@@ -281,9 +307,11 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
def __init__(
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
) -> None:
self.group = group
self.next_group: Optional[ModuleGroup] = None
self.next_group = next_group
self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -431,8 +459,8 @@ class LayerExecutionTrackerHook(ModelHook):
def apply_group_offloading(
module: torch.nn.Module,
onload_device: Union[str, torch.device],
offload_device: Union[str, torch.device] = torch.device("cpu"),
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
@@ -518,8 +546,6 @@ def apply_group_offloading(
```
"""
onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
offload_type = GroupOffloadingType(offload_type)
stream = None
@@ -607,7 +633,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, config=config)
_apply_group_offloading_hook(group_module, group, None, config=config)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -636,9 +662,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -667,7 +693,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(submodule, group, config=config)
_apply_group_offloading_hook(submodule, group, None, config=config)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -714,7 +740,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, config=config)
_apply_group_offloading_hook(parent_module, group, None, config=config)
if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -736,12 +762,13 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=_GROUP_ID_LAZY_LEAF,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -750,13 +777,14 @@ def _apply_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, config=config)
hook = GroupOffloadingHook(group, next_group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -765,7 +793,7 @@ def _apply_lazy_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, config=config)
hook = GroupOffloadingHook(group, next_group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
-2
View File
@@ -79,7 +79,6 @@ if is_torch_available():
"WanLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -119,7 +118,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,
QwenImageLoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
SkyReelsV2LoraLoaderMixin,
-342
View File
@@ -6538,348 +6538,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class QwenImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`QwenImageTransformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
-1
View File
@@ -61,7 +61,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
}
+2 -5
View File
@@ -30,7 +30,6 @@ from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError
from ..quantizers import DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
@@ -232,7 +231,6 @@ def load_model_dict_into_meta(
"""
is_quantized = hf_quantizer is not None
is_higgs = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HIGGS
empty_state_dict = model.state_dict()
for param_name, param in state_dict.items():
@@ -282,8 +280,7 @@ def load_model_dict_into_meta(
# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
# higgs quants repack the weights so they will have different shapes
if empty_state_dict[param_name].shape != param.shape and not is_higgs:
if empty_state_dict[param_name].shape != param.shape:
if (
is_quantized
and hf_quantizer.pre_quantized
@@ -307,7 +304,7 @@ def load_model_dict_into_meta(
hf_quantizer.create_quantized_param(
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
)
elif hf_quantizer is not None:
else:
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
return offload_index, state_dict_index
@@ -312,14 +312,15 @@ class AudioLDM2Pipeline(DiffusionPipeline):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
if is_transformers_version("<", "4.52.0.dev0"):
cache_position_kwargs["input_ids"] = inputs_embeds
cache_position_kwargs["model_kwargs"] = model_kwargs
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
@@ -20,7 +20,6 @@ import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import QwenImageLoraLoaderMixin
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -45,12 +44,12 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import QwenImagePipeline
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/QwenImage-20B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=50).images[0]
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
>>> image.save("qwenimage.png")
```
"""
@@ -129,7 +128,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
class QwenImagePipeline(DiffusionPipeline):
r"""
The QwenImage pipeline for text-to-image generation.
@@ -636,11 +635,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -659,7 +653,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -673,7 +667,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
-4
View File
@@ -21,11 +21,9 @@ from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .higgs import HiggsQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
HiggsConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -41,7 +39,6 @@ AUTO_QUANTIZER_MAPPING = {
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"higgs": HiggsQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -50,7 +47,6 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"higgs": HiggsConfig,
}
+3 -92
View File
@@ -12,15 +12,15 @@
# # See the License for the specific language governing permissions and
# # limitations under the License.
import inspect
import os
from contextlib import nullcontext
import gguf
import torch
import torch.nn as nn
from ...utils import is_accelerate_available, is_kernels_available
from ...utils import is_accelerate_available
if is_accelerate_available():
@@ -29,82 +29,6 @@ if is_accelerate_available():
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
can_use_cuda_kernels = (
os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 7
)
if can_use_cuda_kernels and is_kernels_available():
from kernels import get_kernel
ops = get_kernel("Isotr0py/ggml")
else:
ops = None
UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
gguf.GGMLQuantizationType.Q4_0,
gguf.GGMLQuantizationType.Q4_1,
gguf.GGMLQuantizationType.Q5_0,
gguf.GGMLQuantizationType.Q5_1,
gguf.GGMLQuantizationType.Q8_0,
gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
gguf.GGMLQuantizationType.Q2_K,
gguf.GGMLQuantizationType.Q3_K,
gguf.GGMLQuantizationType.Q4_K,
gguf.GGMLQuantizationType.Q5_K,
gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
gguf.GGMLQuantizationType.IQ1_M,
gguf.GGMLQuantizationType.IQ1_S,
gguf.GGMLQuantizationType.IQ2_XXS,
gguf.GGMLQuantizationType.IQ2_XS,
gguf.GGMLQuantizationType.IQ2_S,
gguf.GGMLQuantizationType.IQ3_XXS,
gguf.GGMLQuantizationType.IQ3_S,
gguf.GGMLQuantizationType.IQ4_XS,
gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
# contiguous batching and inefficient with diffusers' batching,
# so we disabled it now.
# elif qweight_type in MMVQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# elif qweight_type in MMQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
if qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.to(x.dtype).T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = gguf.GGMLQuantizationType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y.as_tensor()
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
@@ -527,24 +451,11 @@ class GGUFLinear(nn.Linear):
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
self.device = device
def forward(self, inputs: torch.Tensor):
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
return self.forward_cuda(inputs)
return self.forward_native(inputs)
def forward_native(self, inputs: torch.Tensor):
def forward(self, inputs):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
def forward_cuda(self, inputs: torch.Tensor):
quant_type = self.weight.quant_type
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
if self.bias is not None:
output += self.bias.to(self.compute_dtype)
return output
@@ -1 +0,0 @@
from .higgs_quantizer import HiggsQuantizer
@@ -1,205 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/quantizers/quantizer_higgs.py
"""
from typing import TYPE_CHECKING, Any, Optional
from ...utils import get_module_from_name
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
from ...utils import is_accelerate_available, is_torch_available, logging
from ...utils.logging import tqdm
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class HiggsQuantizer(DiffusersQuantizer):
"""
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of
full-precision models.
"""
requires_calibration = False
requires_parameters_quantization = True
required_packages = ["flute-kernel", "fast_hadamard_transform"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
def validate_environment(self, device_map, **kwargs):
if not torch.cuda.is_available():
raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")
if not is_accelerate_available():
raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")
# TODO: enable this.
# if not is_flute_available():
# raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel>=0.3.0`")
# if not is_hadamard_available():
# raise ImportError(
# "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
# )
if device_map is None:
raise ValueError(
"You are attempting to load a HIGGS model without setting device_map."
" Please set device_map comprised of 'cuda' devices."
)
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
)
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
torch_dtype = torch.float16
elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
raise ValueError(
f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
)
return torch_dtype
def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
unexpected_keys: Optional[list[str]] = None,
):
from .utils import quantize_with_higgs
"""
Quantizes weights into weight and weight_scale
"""
flute_dict = quantize_with_higgs(
param_value.to(target_device),
self.quantization_config.bits,
self.quantization_config.p,
self.quantization_config.group_size,
self.quantization_config.hadamard_size,
)
del param_value
module, _ = get_module_from_name(model, param_name)
module_name = ".".join(param_name.split(".")[:-1])
for key, value in flute_dict.items():
if key in module._parameters:
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
elif key in module._buffers:
module._buffers[key] = torch.nn.Buffer(value)
elif key == "tune_metadata":
module.tune_metadata = value
self.quantization_config.tune_metadata[module_name] = value.to_dict()
else:
raise ValueError(f"Unexpected key {key} in module {module}")
if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from .utils import HiggsLinear
higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
def should_update(key: str) -> bool:
if key.endswith(".weight") or key.endswith(".bias"):
return False
full_key = f"{prefix}.{key}"
return any(name in key or name in full_key for name in higgs_names)
return [key for key in missing_keys if not should_update(key)]
@property
def is_trainable(self):
return False
def is_serializable(self):
return True
def check_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
from .utils import HiggsLinear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
# Only quantize weights of HiggsLinear modules that are not already quantized
return True
else:
return False
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
**kwargs,
):
from .utils import replace_with_higgs_linear
replace_with_higgs_linear(model, quantization_config=self.quantization_config)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
from flute.tune import TuneMetaData, maybe_tune_and_repack
from flute.utils import make_workspace_streamk
from .utils import HiggsLinear
flute_workspaces = {}
flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
module.workspace = flute_workspaces[module.weight.device]
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
module.weight.data, module.tune_metadata = maybe_tune_and_repack(
weight=module.weight.data,
scales=module.scales.data,
metadata=module.tune_metadata,
)
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
def _dequantize(self, model):
from .utils import dequantize_higgs
model = dequantize_higgs(model)
return model
-690
View File
@@ -1,690 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file.
Taken from:
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/integrations/higgs.py
"""
from math import sqrt
from ...utils import (
# TODO enable:
# is_flute_available,
# is_hadamard_available,
is_torch_available,
logging,
)
if is_torch_available():
import torch
from torch import nn
# if is_flute_available():
# if is_hadamard_available():
from fast_hadamard_transform import hadamard_transform
from flute.integrations.higgs import prepare_data_transposed
from flute.tune import TuneMetaData, qgemm_v2
logger = logging.get_logger(__name__)
def pad_to_block(tensor, dims, had_block_size, value=0):
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
for dim in dims:
size = tensor.shape[dim]
next_multiple_of_1024 = ((size - 1) // had_block_size + 1) * had_block_size
delta = next_multiple_of_1024 - size
pad_dims[-2 * dim - 1] = delta
return nn.functional.pad(tensor, pad_dims, "constant", value)
def get_higgs_grid(p: int, n: int):
if (p, n) == (2, 256):
return torch.tensor(
[
[-2.501467704772949, 0.17954708635807037],
[-0.6761789321899414, 1.2728623151779175],
[-1.8025816679000854, 0.7613157629966736],
[-0.538287878036499, -2.6028504371643066],
[0.8415029644966125, -0.8600977659225464],
[0.7023013234138489, 3.3138747215270996],
[0.5699077844619751, 2.5782253742218018],
[3.292393207550049, -0.6016128063201904],
[0.5561617016792297, -1.7723814249038696],
[-2.1012380123138428, 0.020958125591278076],
[0.46085724234580994, 0.8428705334663391],
[1.4548040628433228, -0.6156039237976074],
[3.210029363632202, 0.3546904921531677],
[0.8893890976905823, -0.5967988967895508],
[0.8618854284286499, -3.2061192989349365],
[1.1360996961593628, -0.23852407932281494],
[1.6646337509155273, -0.9265465140342712],
[1.4767773151397705, 1.2476022243499756],
[-1.0511897802352905, 1.94503915309906],
[-1.56318998336792, -0.3264186680316925],
[-0.1829211413860321, 0.2922491431236267],
[-0.8950616717338562, -1.3887052536010742],
[-0.08206957578659058, -1.329533576965332],
[-0.487422913312912, 1.4817842245101929],
[-1.6769757270812988, -2.8269758224487305],
[-1.5057679414749146, 1.8905963897705078],
[1.8335362672805786, 1.0515104532241821],
[0.3273945450782776, 1.0491033792495728],
[-3.295924186706543, -0.7021600008010864],
[-1.8428784608840942, -1.2315762042999268],
[-0.8575026392936707, -1.7005949020385742],
[-1.120667815208435, 0.6467998027801514],
[-0.1588846743106842, -1.804071068763733],
[-0.8539647459983826, 0.5645008683204651],
[-1.4192019701004028, -0.6175029873847961],
[1.0799058675765991, 1.7871345281600952],
[1.171311855316162, 0.7511613965034485],
[2.162078380584717, 0.8044339418411255],
[1.3969420194625854, -1.243762493133545],
[-0.23818807303905487, 0.053944624960422516],
[2.304199457168579, -1.2667627334594727],
[1.4225027561187744, 0.568610668182373],
[0.376836895942688, -0.7134661674499512],
[2.0404467582702637, 0.4087389409542084],
[0.7639489769935608, -1.1367933750152588],
[0.3622530400753021, -1.4827953577041626],
[0.4100743532180786, 0.36108437180519104],
[-1.5867475271224976, -1.618212342262268],
[-2.2769672870635986, -1.2132309675216675],
[0.9184022545814514, -0.34428009390830994],
[-0.3902314603328705, 0.21785245835781097],
[3.120687484741211, 1.3077973127365112],
[1.587440848350525, -1.6506884098052979],
[-1.718808889389038, -0.038405973464250565],
[-0.6888407468795776, -0.8402308821678162],
[-0.7981445789337158, -1.1117373704910278],
[-2.4124443531036377, 1.3419722318649292],
[-0.6611530184745789, 0.9939885139465332],
[-0.33103418350219727, -0.16702833771705627],
[-2.4091389179229736, -2.326857566833496],
[1.6610108613967896, -2.159703254699707],
[0.014884627424180508, 0.3887578248977661],
[0.029668325558304787, 1.8786455392837524],
[1.180362582206726, 2.699317216873169],
[1.821286678314209, -0.5960053205490112],
[-0.44835323095321655, 3.327436685562134],
[-0.3714401423931122, -2.1466753482818604],
[-1.1103475093841553, -2.4536871910095215],
[-0.39110705256462097, 0.6670510172843933],
[0.474752813577652, -1.1959707736968994],
[-0.013110585510730743, -2.52519154548645],
[-2.0836575031280518, -1.703289270401001],
[-1.1077687740325928, -0.1252644956111908],
[-0.4138077199459076, 1.1837692260742188],
[-1.977599024772644, 1.688241720199585],
[-1.659559965133667, -2.1387736797332764],
[0.03242531046271324, 0.6526556015014648],
[0.9127950072288513, 0.6099498867988586],
[-0.38478314876556396, 0.433487206697464],
[0.27454206347465515, -0.27719801664352417],
[0.10388526320457458, 2.2812814712524414],
[-0.014394169673323631, -3.177137613296509],
[-1.2871228456497192, -0.8961855173110962],
[0.5720916986465454, -0.921597957611084],
[1.1159656047821045, -0.7609877586364746],
[2.4383342266082764, -2.2983546257019043],
[-0.294057160615921, -0.9770799875259399],
[-0.9342701435089111, 1.107579231262207],
[-1.549338698387146, 3.090520143508911],
[2.6076579093933105, 2.051239013671875],
[-0.9259037375450134, 1.407211184501648],
[-0.1747353971004486, 0.540488600730896],
[-0.8963701725006104, 0.8271111249923706],
[0.6480194926261902, 1.0128909349441528],
[0.980783998966217, -0.06156221032142639],
[-0.16883476078510284, 1.0601658821105957],
[0.5839992761611938, 0.004697148688137531],
[-0.34228450059890747, -1.2423977851867676],
[2.500824451446533, 0.3665279746055603],
[-0.17641609907150269, 1.3529551029205322],
[0.05378641560673714, 2.817232847213745],
[-1.2391047477722168, 2.354328155517578],
[0.630434513092041, -0.668536365032196],
[1.7576488256454468, 0.6738647818565369],
[0.4435231387615204, 0.6000469326972961],
[-0.08794835954904556, -0.11511358618736267],
[1.6540337800979614, 0.33995017409324646],
[-0.04202975332736969, -0.5375117063522339],
[-0.4247745871543884, -0.7897617220878601],
[0.06695003807544708, 1.2000739574432373],
[-3.2508881092071533, 0.28734830021858215],
[-1.613816261291504, 0.4944162368774414],
[1.3598989248275757, 0.26117825508117676],
[2.308382511138916, 1.3462618589401245],
[-1.2137469053268433, -1.9254342317581177],
[-0.4889402985572815, 1.8136259317398071],
[-0.1870335340499878, -0.3480615019798279],
[1.0766386985778809, -1.0627082586288452],
[0.4651014506816864, 2.131748914718628],
[-0.1306295394897461, -0.7811847925186157],
[0.06433182954788208, -1.5397958755493164],
[-0.2894323468208313, -0.5789554715156555],
[-0.6081662178039551, 0.4845278263092041],
[2.697964668273926, -0.18515698611736298],
[0.1277363896369934, -0.7221432328224182],
[0.8700758218765259, 0.35042452812194824],
[0.22088994085788727, 0.495242178440094],
[-2.5843818187713623, -0.8000828623771667],
[0.6732649803161621, -1.4362232685089111],
[-1.5286413431167603, 1.0417330265045166],
[-1.1222513914108276, -0.6269875764846802],
[-0.9752035140991211, -0.8750635385513306],
[-2.6369473934173584, 0.6918523907661438],
[0.14478731155395508, -0.041986867785453796],
[-1.5629483461380005, 1.4369450807571411],
[0.38952457904815674, -2.16428804397583],
[-0.16885095834732056, 0.7976621985435486],
[-3.12416934967041, 1.256506085395813],
[0.6843105554580688, -0.4203019142150879],
[1.9345275163650513, 1.934950351715088],
[0.012184220366179943, -2.1080918312072754],
[-0.6350273489952087, 0.7358828186988831],
[-0.837304949760437, -0.6214472651481628],
[0.08211923390626907, -0.9472538232803345],
[2.9332995414733887, -1.4956780672073364],
[1.3806978464126587, -0.2916182279586792],
[0.06773144006729126, 0.9285762310028076],
[-1.1943119764328003, 1.5963770151138306],
[1.6395620107650757, -0.32285431027412415],
[-1.390851378440857, -0.08273141086101532],
[1.816330909729004, -1.2812227010726929],
[0.7921574711799622, -2.1135804653167725],
[0.5817914605140686, 1.2644577026367188],
[1.929347038269043, -0.2386285960674286],
[0.8877345323562622, 1.190008521080017],
[1.4732073545455933, 0.8935023546218872],
[-2.8518524169921875, -1.5478795766830444],
[0.2439267635345459, 0.7576767802238464],
[0.5246709585189819, -2.606659412384033],
[1.150876760482788, 1.4073830842971802],
[-0.2643202245235443, 2.0634236335754395],
[1.555483341217041, -0.0023102816194295883],
[2.0830578804016113, -1.7225427627563477],
[-0.5424830317497253, -1.070199728012085],
[0.9168899655342102, 0.8955540060997009],
[-0.8120972514152527, 2.696739912033081],
[-0.29908373951911926, -1.5310651063919067],
[1.2320337295532227, -1.556247353553772],
[1.8612544536590576, 0.08704725652933121],
[0.22133447229862213, -1.8091708421707153],
[-0.4403655230998993, -0.38571012020111084],
[-1.88539457321167, 1.192205786705017],
[2.239687919616699, 0.004709010478109121],
[1.139495611190796, 0.45733731985092163],
[-1.507995367050171, 0.19716016948223114],
[0.46986445784568787, 1.5422041416168213],
[-1.2573751211166382, -0.35984551906585693],
[-1.7415345907211304, -0.6020717024803162],
[1.0751984119415283, 0.19006384909152985],
[2.24186635017395, -0.46343153715133667],
[0.3610347509384155, -0.07658443599939346],
[-1.3111497163772583, 0.432013601064682],
[0.6164408326148987, 0.24538464844226837],
[-1.9266542196273804, -0.3256155550479889],
[-0.5870336890220642, -0.1879584938287735],
[-1.0476511716842651, 0.3677721917629242],
[-1.229940414428711, 1.2433830499649048],
[0.18550436198711395, 0.22753673791885376],
[-0.017921989783644676, 0.12625974416732788],
[1.1659504175186157, -0.5020995736122131],
[-0.5983408093452454, -1.40438973903656],
[0.7519024014472961, -0.16282692551612854],
[0.9920787811279297, -1.344896912574768],
[-0.8103678226470947, 0.3064485788345337],
[0.6956969499588013, 1.8208192586898804],
[-2.7830491065979004, -0.2299390584230423],
[-0.34681546688079834, 2.4890666007995605],
[-1.4452646970748901, -1.2216600179672241],
[-2.1872897148132324, 0.8926076292991638],
[1.706072211265564, -2.8440372943878174],
[1.1119003295898438, -2.4923460483551025],
[-2.582794666290283, 2.0973289012908936],
[0.04987720400094986, -0.2964983284473419],
[-2.063807487487793, -0.7847916483879089],
[-0.4068813621997833, 0.9135897755622864],
[-0.9814359545707703, -0.3874954879283905],
[-1.4227229356765747, 0.7337291240692139],
[0.3065044581890106, 1.3125417232513428],
[1.2160996198654175, -1.9643305540084839],
[-1.2163853645324707, 0.14608727395534515],
[-2.3030710220336914, -0.37558120489120483],
[0.9232977628707886, 2.1843791007995605],
[-0.1989777386188507, 1.651851773262024],
[-0.714374840259552, -0.39365994930267334],
[-0.7805715799331665, -2.099881887435913],
[0.9015759229660034, -1.7053706645965576],
[0.1033422127366066, 1.5256654024124146],
[-1.8773194551467896, 2.324174165725708],
[1.9227174520492554, 2.7441604137420654],
[-0.5994020104408264, 0.23984014987945557],
[1.3496100902557373, -0.9126054644584656],
[-0.8765304088592529, -3.1877026557922363],
[-1.2040035724639893, -1.5169521570205688],
[1.4261796474456787, 2.150200128555298],
[1.463774561882019, 1.6656692028045654],
[0.20364105701446533, -0.4988172650337219],
[0.5195154547691345, -0.24067887663841248],
[-1.1116786003112793, -1.1599653959274292],
[-0.8490808606147766, -0.1681060940027237],
[0.3189965784549713, -0.9641751646995544],
[-0.5664751529693604, -0.5951744318008423],
[-1.6347930431365967, -0.9137664437294006],
[0.44048091769218445, -0.47259435057640076],
[-2.147747039794922, 0.47442489862442017],
[1.834734320640564, 1.4462147951126099],
[1.1777573823928833, 1.0659226179122925],
[-0.9568989872932434, 0.09495053440332413],
[-1.838529348373413, 0.2950586676597595],
[-0.4800611734390259, 0.014894310384988785],
[-0.5235516428947449, -1.7687653303146362],
[2.0735011100769043, -0.8825281262397766],
[2.637502431869507, 0.8455678224563599],
[2.606602907180786, -0.7848446369171143],
[-1.1886937618255615, 0.9330510497093201],
[0.38082656264305115, 0.13328030705451965],
[0.6847941875457764, 0.7384101152420044],
[1.2638574838638306, -0.007309418171644211],
[0.18292222917079926, -1.22371244430542],
[0.8143821954727173, 1.4976691007614136],
[0.6571850776672363, 0.48368802666664124],
[-0.6991601586341858, 2.150190830230713],
[0.8101756572723389, 0.10206498205661774],
[-0.08768226951360703, -1.084917664527893],
[-0.7208092212677002, 0.03657956421375275],
[0.3211449086666107, 1.803687334060669],
[-0.7835946083068848, 1.6869111061096191],
]
)
if (p, n) == (2, 64):
return torch.tensor(
[
[-2.7216711044311523, 0.14431366324424744],
[-0.766914427280426, 1.7193410396575928],
[-2.2575762271881104, 1.2476624250411987],
[1.233758807182312, -2.3560616970062256],
[0.8701965808868408, -0.2649352252483368],
[1.4506438970565796, 2.1776366233825684],
[-0.06305818259716034, 1.9049758911132812],
[2.536226511001587, 0.563927412033081],
[0.4599496126174927, -1.8745561838150024],
[-1.900517225265503, -0.30703988671302795],
[0.09386251866817474, 0.8755807280540466],
[1.946500539779663, -0.6743080615997314],
[2.1338934898376465, 1.4581491947174072],
[0.9429940581321716, -0.8038390278816223],
[2.0697755813598633, -1.614896535873413],
[0.772676408290863, 0.22017823159694672],
[1.0689979791641235, -1.525044322013855],
[0.6813604831695557, 1.1345642805099487],
[0.4706456661224365, 2.606626272201538],
[-1.294018030166626, -0.4372096061706543],
[-0.09134224057197571, 0.4610418677330017],
[-0.7907772064208984, -0.48412787914276123],
[0.060459110885858536, -0.9172890186309814],
[-0.5855047702789307, 2.56172513961792],
[0.11484206467866898, -2.659848213195801],
[-1.5893300771713257, 2.188580274581909],
[1.6750942468643188, 0.7089915871620178],
[-0.445697546005249, 0.7452405095100403],
[-1.8539940118789673, -1.8377939462661743],
[-1.5791912078857422, -1.017285943031311],
[-1.030419945716858, -1.5746369361877441],
[-1.9511750936508179, 0.43696075677871704],
[-0.3446580767631531, -1.8953213691711426],
[-1.4219647645950317, 0.7676230669021606],
[-0.9191089272499084, 0.5021472573280334],
[0.20464491844177246, 1.3684605360031128],
[0.5402919054031372, 0.6699410676956177],
[1.8903915882110596, 0.03638288006186485],
[0.4723062515258789, -0.6216739416122437],
[-0.41345009207725525, -0.22752176225185394],
[2.7119064331054688, -0.5111885070800781],
[1.065286636352539, 0.6950305700302124],
[0.40629103779792786, -0.14339995384216309],
[1.2815024852752686, 0.17108257114887238],
[0.01785222627222538, -0.43778058886528015],
[0.054590027779340744, -1.4225547313690186],
[0.3076786696910858, 0.30697619915008545],
[-0.9498570561408997, -0.9576997756958008],
[-2.4640724658966064, -0.9660449028015137],
[1.3714425563812256, -0.39760473370552063],
[-0.4857747256755829, 0.2386789172887802],
[1.2797833681106567, 1.3097363710403442],
[0.5508887767791748, -1.1777795553207397],
[-1.384316325187683, 0.1465839296579361],
[-0.46556955575942993, -1.2442727088928223],
[-0.3915477693080902, -0.7319604158401489],
[-1.4005504846572876, 1.3890998363494873],
[-0.8647305965423584, 1.0617644786834717],
[-0.8901953101158142, -0.01650036871433258],
[-0.9893633723258972, -2.4662880897521973],
[1.445534110069275, -1.049334168434143],
[-0.041650623083114624, 0.012734669260680676],
[-0.3302375078201294, 1.26217782497406],
[0.6934980154037476, 1.7714335918426514],
]
)
elif (p, n) == (2, 16):
return torch.tensor(
[
[-0.8996632695198059, -1.6360418796539307],
[-0.961183488368988, 1.5999565124511719],
[-1.882026195526123, 0.678778350353241],
[0.36300793290138245, -1.9667866230010986],
[-0.6814072728157043, -0.576818585395813],
[0.7270012497901917, 0.6186859607696533],
[0.3359416127204895, 1.8371193408966064],
[1.859930396080017, 0.036668598651885986],
[0.17208248376846313, -0.9401724338531494],
[-1.7599700689315796, -0.6244229674339294],
[-0.8993809223175049, 0.32267823815345764],
[0.839488685131073, -0.3017036020755768],
[1.5314953327178955, 1.2942044734954834],
[-0.0011779458727687597, 0.00022069070837460458],
[1.4274526834487915, -1.207889199256897],
[-0.16123905777931213, 0.8787511587142944],
]
)
elif (p, n) == (1, 16):
return torch.tensor(
[
[-2.7325894832611084],
[-2.069017171859741],
[-1.6180464029312134],
[-1.2562311887741089],
[-0.9423404335975647],
[-0.6567591428756714],
[-0.38804829120635986],
[-0.12839503586292267],
[0.12839503586292267],
[0.38804829120635986],
[0.6567591428756714],
[0.9423404335975647],
[1.2562311887741089],
[1.6180464029312134],
[2.069017171859741],
[2.7325894832611084],
]
)
elif (p, n) == (1, 8):
return torch.tensor(
[
[-2.1519455909729004],
[-1.3439092636108398],
[-0.7560052871704102],
[-0.2450941801071167],
[0.2450941801071167],
[0.7560052871704102],
[1.3439092636108398],
[2.1519455909729004],
]
)
elif (p, n) == (1, 4):
return torch.tensor([[-1.5104175806045532], [-0.4527800381183624], [0.4527800381183624], [1.5104175806045532]])
else:
raise NotImplementedError(f"Unsupported p={p}, n={n}")
def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024):
assert len(weight.shape) == 2, "Only 2D weights are supported for now"
grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device)
grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2
device = weight.device
dtype = weight.dtype
weight = weight.to(copy=True, dtype=torch.float32)
# Pad to Hadamard transform size
weight = pad_to_block(weight, [1], hadamard_size)
# Scale and Hadamard transform
mult = weight.shape[1] // hadamard_size
weight = weight.reshape(-1, mult, hadamard_size)
scales = torch.linalg.norm(weight, axis=-1)
weight = hadamard_transform(weight, 1) / scales[:, :, None]
# Pad to edenn_d and project
weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p)
# Quantize
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
for i in range(0, weight.shape[0], 16):
codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
del weight
codes = codes.reshape(codes.shape[0], -1)
scales = scales / sqrt(hadamard_size)
weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
codes,
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
grid.to(dtype),
num_bits=bits,
group_size=group_size,
vector_size=p,
dtype=dtype,
device=device,
check_correctness=False,
)
return {
"weight": weight,
"scales": scales,
"tables": tables,
"tables2": tables2.view(dtype=torch.float16),
"tune_metadata": tune_metadata,
}
class HiggsLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
num_bits: int,
bias=True,
dtype: torch.dtype = None,
device: torch.device = None,
group_size: int = 256,
hadamard_size: int = 1024,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_bits = num_bits
self.group_size = group_size
self.hadamard_size = hadamard_size
assert in_features % group_size == 0
assert num_bits in [2, 3, 4]
self.weight = nn.Parameter(
torch.empty((out_features * num_bits // 16, in_features), dtype=torch.int16, device=device),
requires_grad=False,
)
self.scales = nn.Parameter(
torch.empty((out_features, in_features // group_size), dtype=dtype, device=device), requires_grad=False
)
self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False)
self.tables2 = nn.Parameter(
torch.empty((2**num_bits, 2**num_bits, 2), dtype=dtype, device=device), requires_grad=False
)
if bias:
self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False)
else:
self.register_parameter("bias", None)
self.workspace = None # must be set externally to be reused among layers
self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent
def forward(self, x):
x = pad_to_block(x, [-1], self.hadamard_size)
if self.workspace is None:
raise Exception("Workspace must be set before calling forward")
return qgemm_v2(
x,
self.weight,
self.scales,
self.tables,
self.tables2.view(dtype=torch.float32),
self.workspace,
self.tune_metadata,
hadamard_size=self.hadamard_size,
)
def _replace_with_higgs_linear(
model, quantization_config=None, current_key_name=None, modules_to_not_convert=None, has_been_replaced=False
):
from accelerate import init_empty_weights
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear):
current_key_name_str = ".".join(current_key_name)
if not any(current_key_name_str.endswith(key) for key in modules_to_not_convert):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
# Original size is [3072, 4096]. But after `HiggsLinear`, this is
# [768, 4096]. 🤯
if name == "context_embedder":
print(f"{in_features=}, {out_features=}")
model._modules[name] = HiggsLinear(
in_features,
out_features,
bias=module.bias is not None,
num_bits=quantization_config.bits,
hadamard_size=quantization_config.hadamard_size,
group_size=quantization_config.group_size,
)
if name == "context_embedder":
print(model._modules[name].weight.shape)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_higgs_linear(
module,
quantization_config=quantization_config,
current_key_name=current_key_name,
modules_to_not_convert=modules_to_not_convert,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_higgs_linear(
model,
quantization_config=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with HIGGS quantized layers.
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
conversion has been successful or not.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`HiggsConfig`):
The quantization config object that contains the quantization parameters.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
modules_to_not_convert = quantization_config.modules_to_not_convert or []
model, _ = _replace_with_higgs_linear(
model, quantization_config, current_key_name, modules_to_not_convert, has_been_replaced
)
has_been_replaced = any(isinstance(replaced_module, HiggsLinear) for _, replaced_module in model.named_modules())
if not has_been_replaced:
logger.warning(
"You are loading your model in Higgs but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
def dequantize_higgs(model, current_key_name=None):
"""
Dequantizes the HiggsLinear layers in the given model by replacing them with standard torch.nn.Linear layers.
Args:
model (torch.nn.Module): The model containing HiggsLinear layers to be dequantized.
current_key_name (list, optional):
A list to keep track of the current module names during recursion. Defaults to None.
Returns:
torch.nn.Module: The model with HiggsLinear layers replaced by torch.nn.Linear layers.
"""
with torch.no_grad():
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, HiggsLinear):
in_features = module.in_features
out_features = module.out_features
model._modules[name] = torch.nn.Linear(
in_features,
out_features,
bias=module.bias is not None,
device=module.scales.device,
dtype=module.scales.dtype,
)
model._modules[name].weight.data = module(
torch.eye(in_features, device=module.scales.device, dtype=module.scales.dtype)
).T.contiguous()
if len(list(module.children())) > 0:
_ = dequantize_higgs(
module,
current_key_name=current_key_name,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model
@@ -46,7 +46,6 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
HIGGS = "higgs"
if is_torchao_available():
@@ -725,62 +724,3 @@ class QuantoConfig(QuantizationConfigMixin):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
@dataclass
class HiggsConfig(QuantizationConfigMixin):
"""
HiggsConfig is a configuration class for quantization using the HIGGS method.
Args:
bits (int, *optional*, defaults to 4):
Number of bits to use for quantization. Can be 2, 3 or 4. Default is 4.
p (int, *optional*, defaults to 2):
Quantization grid dimension. 1 and 2 are supported. 2 is always better in practice. Default is 2.
modules_to_not_convert (`list`, *optional*, default to ["lm_head"]):
List of linear layers that should not be quantized.
hadamard_size (int, *optional*, defaults to 512):
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value.
Decreasing this below 512 will reduce the quality of the quantization.
group_size (int, *optional*, defaults to 256):
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance.
Default is 256. Must be a divisor of hadamard_size.
tune_metadata ('dict', *optional*, defaults to {}):
Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default
is an empty dictionary. Is set automatically during tuning.
"""
def __init__(
self,
bits: int = 4,
p: int = 2,
modules_to_not_convert: Optional[list[str]] = None,
hadamard_size: int = 512,
group_size: int = 256,
tune_metadata: Optional[dict[str, Any]] = None,
**kwargs,
):
if tune_metadata is None:
tune_metadata = {}
self.quant_method = QuantizationMethod.HIGGS
self.bits = bits
self.p = p
self.modules_to_not_convert = modules_to_not_convert
self.hadamard_size = hadamard_size
self.group_size = group_size
self.tune_metadata = tune_metadata
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if self.bits not in [2, 3, 4]:
raise ValueError("bits must be 2, 3, or 4")
if self.p not in [1, 2]:
raise ValueError("p must be 1 or 2. 2 is always better in practice")
if self.group_size not in [64, 128, 256]:
raise ValueError("group_size must be 64, 128, or 256")
if self.hadamard_size % self.group_size != 0:
raise ValueError("hadamard_size must be divisible by group_size")
-1
View File
@@ -81,7 +81,6 @@ from .import_utils import (
is_invisible_watermark_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,
-5
View File
@@ -192,7 +192,6 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -278,10 +277,6 @@ def is_accelerate_available():
return _accelerate_available
def is_kernels_available():
return _kernels_available
def is_k_diffusion_available():
return _k_diffusion_available
-13
View File
@@ -36,7 +36,6 @@ from .import_utils import (
is_compel_available,
is_flax_available,
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -635,18 +634,6 @@ def require_torchao_version_greater_or_equal(torchao_version):
return decorator
def require_kernels_version_greater_or_equal(kernels_version):
def decorator(test_case):
correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version)
return unittest.skipUnless(
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
-87
View File
@@ -17,9 +17,7 @@ import gc
import unittest
import torch
from parameterized import parameterized
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
@@ -101,29 +99,6 @@ class DummyModelWithMultipleBlocks(ModelMixin):
return x
# Test for https://github.com/huggingface/diffusers/pull/12077
class DummyModelWithLayerNorm(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
super().__init__()
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.blocks = torch.nn.ModuleList(
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
)
self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = self.activation(x)
for block in self.blocks:
x = block(x)
x = self.layer_norm(x)
x = self.linear_2(x)
return x
class DummyPipeline(DiffusionPipeline):
model_cpu_offload_seq = "model"
@@ -138,16 +113,6 @@ class DummyPipeline(DiffusionPipeline):
return x
class LayerOutputTrackerHook(ModelHook):
def __init__(self):
super().__init__()
self.outputs = []
def post_forward(self, module, output):
self.outputs.append(output)
return output
@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
@@ -293,7 +258,6 @@ class GroupOffloadTests(unittest.TestCase):
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
model = DummyModelWithMultipleBlocks(
in_features=self.in_features,
hidden_features=self.hidden_features,
@@ -310,54 +274,3 @@ class GroupOffloadTests(unittest.TestCase):
with context:
model(self.input)
@parameterized.expand([("block_level",), ("leaf_level",)])
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
for name, module in model.named_modules():
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerOutputTrackerHook()
registry.register_hook(hook, "layer_output_tracker")
model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
model = DummyModelWithLayerNorm(128, 256, 128, 2)
model.load_state_dict(model_ref.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
apply_layer_output_tracker_hook(model_ref)
apply_layer_output_tracker_hook(model)
x = torch.randn(2, 128).to(torch_device)
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
num_repeats = 4
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
assert ref_name == name
ref_outputs = (
HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
)
outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
cumulated_absmax = 0.0
for i in range(len(outputs)):
diff = ref_outputs[0] - outputs[i]
absdiff = diff.abs()
absmax = absdiff.max().item()
cumulated_absmax += absmax
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)
-129
View File
@@ -1,129 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from diffusers import (
AutoencoderKLQwenImage,
FlowMatchEulerDiscreteScheduler,
QwenImagePipeline,
QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
transformer_cls = QwenImageTransformer2DModel
z_dim = 4
vae_kwargs = {
"base_dim": z_dim * 6,
"z_dim": z_dim,
"dim_mult": [1, 2, 4],
"num_res_blocks": 1,
"temperal_downsample": [False, True],
"latents_mean": [0.0] * 4,
"latents_std": [1.0] * 4,
}
vae_cls = AutoencoderKLQwenImage
tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen25VLForCondGen"
text_encoder_cls, text_encoder_id = (
Qwen2_5_VLForConditionalGeneration,
"hf-internal-testing/tiny-random-Qwen25VLForCondGen",
)
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
@property
def output_shape(self):
return (1, 8, 8, 3)
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 4,
"guidance_scale": 0.0,
"height": 8,
"width": 8,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+1 -11
View File
@@ -45,7 +45,6 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import is_transformers_version
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
@@ -221,11 +220,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
}
return inputs
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -318,6 +312,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components = self.get_dummy_components()
audioldm_pipe = AudioLDM2Pipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
@@ -376,11 +371,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(audio_1 - audio_2).max() < 1e-2
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
+1 -1
View File
@@ -160,7 +160,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertEqual(generated_image.shape, (3, 32, 32))
# fmt: off
expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208])
expected_slice = torch.tensor([0.563, 0.6358, 0.6028, 0.5656, 0.5806, 0.5512, 0.5712, 0.6331, 0.4147, 0.3558, 0.5625, 0.4831, 0.4957, 0.5258, 0.4075, 0.5018])
# fmt: on
generated_slice = generated_image.flatten()
+1 -58
View File
@@ -30,10 +30,8 @@ from diffusers.utils.testing_utils import (
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_accelerator,
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_kernels_version_greater_or_equal,
require_peft_backend,
require_torch_version_greater,
torch_device,
@@ -43,66 +41,11 @@ from ..test_torch_compile_utils import QuantCompileTests
if is_gguf_available():
import gguf
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
enable_full_determinism()
@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
@require_kernels_version_greater_or_equal("0.9.0")
class GGUFCudaKernelsTests(unittest.TestCase):
def setUp(self):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)
def test_cuda_kernels_vs_native(self):
if torch_device != "cuda":
self.skipTest("CUDA kernels test requires CUDA device")
from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
if not can_use_cuda_kernels:
self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
test_quant_types = ["Q4_0", "Q4_K"]
test_shape = (1, 64, 512) # batch, seq_len, hidden_dim
compute_dtype = torch.bfloat16
for quant_type in test_quant_types:
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
in_features, out_features = 512, 512
torch.manual_seed(42)
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
weight = GGUFParameter(weight_data, quant_type=qtype)
x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
linear.weight = weight
linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
linear = linear.to(torch_device)
with torch.no_grad():
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
@nightly
@require_big_accelerator
@require_accelerate
@@ -650,7 +593,7 @@ class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.Test
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
"hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(