Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 560fb5f4d6 | |||
| 8ab26ac9bf | |||
| 9f305e7ce2 | |||
| 2c25bf5bef | |||
| 0e14cacffc | |||
| 13ea83f0fa | |||
| 2b432ac5a8 | |||
| 263b973466 | |||
| a663a67ea2 | |||
| 526858c801 | |||
| b3d2dd36d7 | |||
| abfa922410 | |||
| 6a7b01f60f | |||
| e8aacda762 | |||
| 12184f4015 | |||
| 6e1d2da194 | |||
| 11b1151840 | |||
| cd4d0d8ffb |
@@ -359,6 +359,8 @@ jobs:
|
||||
test_location: "bnb"
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
python utils/print_env.py
|
||||
- name: PyTorch CUDA checkpoint tests on Ubuntu
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
@@ -137,7 +137,7 @@ jobs:
|
||||
|
||||
- name: Run PyTorch CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
- name: Test installing diffusers and importing
|
||||
run: |
|
||||
pip install diffusers && pip uninstall diffusers -y
|
||||
pip install -i https://testpypi.python.org/pypi diffusers
|
||||
pip install -i https://test.pypi.org/simple/ diffusers
|
||||
python -c "from diffusers import __version__; print(__version__)"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
|
||||
|
||||
@@ -305,6 +305,10 @@ image = control_pipe(
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
## Note about `unload_lora_weights()` when using Flux LoRAs
|
||||
|
||||
When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
|
||||
|
||||
## Running FP16 inference
|
||||
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
@@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
|
||||
The example below only quantizes the weights to int8.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
@@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained(
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Without quantization: ~31.447 GB
|
||||
# With quantization: ~20.40 GB
|
||||
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
|
||||
@@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use
|
||||
|
||||
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
|
||||
|
||||
## Serializing and Deserializing quantized models
|
||||
|
||||
To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
|
||||
```
|
||||
|
||||
To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel
|
||||
|
||||
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
# Serialize the model
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=TorchAoConfig("uint4wo"),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
|
||||
# ...
|
||||
|
||||
# Load the model
|
||||
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
|
||||
with init_empty_weights():
|
||||
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
|
||||
check_min_version("0.32.0")
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.32.0.dev0")
|
||||
check_min_version("0.32.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.32.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.32.0.dev0"
|
||||
__version__ = "0.32.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -28,13 +28,20 @@ from ..models.modeling_utils import ModelMixin, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
@@ -43,6 +50,8 @@ from ..utils import (
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_peft_available():
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
|
||||
return weight_name
|
||||
|
||||
|
||||
def _load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
text_encoder_name="text_encoder",
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
@@ -327,27 +482,7 @@ class LoraBaseMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
|
||||
@@ -973,3 +973,178 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
|
||||
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
|
||||
else:
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
|
||||
else:
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict):
|
||||
hidden_size = 3072
|
||||
|
||||
if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
|
||||
linear1_weight = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
|
||||
".linear1.lora_A.weight"
|
||||
)
|
||||
state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
|
||||
state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
|
||||
state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
|
||||
state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
|
||||
else:
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
||||
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
|
||||
".linear1.lora_B.weight"
|
||||
)
|
||||
state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
|
||||
state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
|
||||
state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
|
||||
state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
|
||||
|
||||
elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
|
||||
linear1_bias = state_dict.pop(key)
|
||||
if "lora_A" in key:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
|
||||
".linear1.lora_A.bias"
|
||||
)
|
||||
state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
|
||||
state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
|
||||
state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
|
||||
state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
|
||||
else:
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
|
||||
".linear1.lora_B.bias"
|
||||
)
|
||||
state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
|
||||
state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
|
||||
state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
|
||||
state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
|
||||
|
||||
else:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
||||
new_key = new_key.replace("linear2", "proj_out")
|
||||
new_key = new_key.replace("q_norm", "attn.norm_q")
|
||||
new_key = new_key.replace("k_norm", "attn.norm_k")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
# Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
|
||||
# and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
|
||||
# sure that both follow the same initial format by stripping off the "transformer." prefix.
|
||||
for key in list(converted_state_dict.keys()):
|
||||
if key.startswith("transformer."):
|
||||
converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
|
||||
if key.startswith("diffusion_model."):
|
||||
converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
|
||||
|
||||
# Rename and remap the state dict keys
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
# Add back the "transformer." prefix
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import (
|
||||
MIN_PEFT_VERSION,
|
||||
@@ -30,20 +29,16 @@ from ..utils import (
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
logging,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora_base import _fetch_state_dict
|
||||
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
|
||||
from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
@@ -140,27 +135,7 @@ class PeftAdapterMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
|
||||
r"""
|
||||
|
||||
@@ -595,10 +595,14 @@ def infer_diffusers_model_type(checkpoint):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
):
|
||||
if checkpoint["img_in.weight"].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
if "model.diffusion_model.img_in.weight" in checkpoint:
|
||||
key = "model.diffusion_model.img_in.weight"
|
||||
else:
|
||||
key = "img_in.weight"
|
||||
|
||||
elif checkpoint["img_in.weight"].shape[1] == 128:
|
||||
if checkpoint[key].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
elif checkpoint[key].shape[1] == 128:
|
||||
model_type = "flux-depth"
|
||||
else:
|
||||
model_type = "flux-dev"
|
||||
|
||||
@@ -21,7 +21,6 @@ import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
@@ -44,13 +43,11 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from .lora_base import _func_optionally_disable_offloading
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -400,27 +397,7 @@ class UNet2DConditionLoadersMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def save_attn_procs(
|
||||
self,
|
||||
|
||||
@@ -718,10 +718,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
hf_quantizer = None
|
||||
|
||||
if hf_quantizer is not None:
|
||||
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
|
||||
if is_bnb_quantization_method and device_map is not None:
|
||||
if device_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
|
||||
"Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
|
||||
)
|
||||
|
||||
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
|
||||
@@ -820,7 +819,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
revision=revision,
|
||||
subfolder=subfolder or "",
|
||||
)
|
||||
if hf_quantizer is not None and is_bnb_quantization_method:
|
||||
# TODO: https://github.com/huggingface/diffusers/issues/10013
|
||||
if hf_quantizer is not None:
|
||||
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
|
||||
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
|
||||
is_sharded = False
|
||||
|
||||
@@ -713,14 +713,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
condition_sequence_length = encoder_hidden_states.shape[1]
|
||||
sequence_length = latent_sequence_length + condition_sequence_length
|
||||
attention_mask = torch.zeros(
|
||||
batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
|
||||
) # [B, N, N]
|
||||
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
|
||||
) # [B, N]
|
||||
|
||||
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
|
||||
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
|
||||
|
||||
for i in range(batch_size):
|
||||
attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
|
||||
attention_mask[i, : effective_sequence_length[i]] = True
|
||||
# [B, 1, 1, N], for broadcasting across attention heads
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
|
||||
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
@@ -35,21 +35,28 @@ if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
# At the moment, only int8 is supported for integer quantization dtypes.
|
||||
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
|
||||
# to support more quantization methods, such as intx_weight_only.
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint1,
|
||||
torch.uint2,
|
||||
torch.uint3,
|
||||
torch.uint4,
|
||||
torch.uint5,
|
||||
torch.uint6,
|
||||
torch.uint7,
|
||||
)
|
||||
if is_torch_version(">=", "2.5"):
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
# At the moment, only int8 is supported for integer quantization dtypes.
|
||||
# In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
|
||||
# to support more quantization methods, such as intx_weight_only.
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint1,
|
||||
torch.uint2,
|
||||
torch.uint3,
|
||||
torch.uint4,
|
||||
torch.uint5,
|
||||
torch.uint6,
|
||||
torch.uint7,
|
||||
)
|
||||
else:
|
||||
SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
)
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import quantize_
|
||||
@@ -125,7 +132,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
def update_torch_dtype(self, torch_dtype):
|
||||
quant_type = self.quantization_config.quant_type
|
||||
|
||||
if quant_type.startswith("int"):
|
||||
if quant_type.startswith("int") or quant_type.startswith("uint"):
|
||||
if torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning(
|
||||
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
|
||||
|
||||
@@ -100,7 +100,7 @@ from .import_utils import (
|
||||
is_xformers_available,
|
||||
requires_backends,
|
||||
)
|
||||
from .loading_utils import get_module_from_name, load_image, load_video
|
||||
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import (
|
||||
|
||||
@@ -148,3 +148,15 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
|
||||
module = new_module
|
||||
tensor_name = splits[-1]
|
||||
return module, tensor_name
|
||||
|
||||
|
||||
def get_submodule_by_name(root_module, module_path: str):
|
||||
current = root_module
|
||||
parts = module_path.split(".")
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
idx = int(part)
|
||||
current = current[idx] # e.g., for nn.ModuleList or nn.Sequential
|
||||
else:
|
||||
current = getattr(current, part)
|
||||
return current
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
@@ -162,6 +163,105 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
)
|
||||
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_lora_expansion_works_for_absent_keys(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
# Modify the config to have a layer which won't be present in the second LoRA we will load.
|
||||
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
|
||||
modified_denoiser_lora_config.target_modules.add("x_embedder")
|
||||
|
||||
pipe.transformer.add_adapter(modified_denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
|
||||
|
||||
# Modify the state dict to exclude "x_embedder" related LoRA params.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
|
||||
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
|
||||
pipe.set_adapters(["one", "two"])
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
|
||||
"Different LoRAs should lead to different results.",
|
||||
)
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
|
||||
def test_lora_expansion_works_for_extra_keys(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
# Modify the config to have a layer which won't be present in the first LoRA we will load.
|
||||
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
|
||||
modified_denoiser_lora_config.target_modules.add("x_embedder")
|
||||
|
||||
pipe.transformer.add_adapter(modified_denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
# Modify the state dict to exclude "x_embedder" related LoRA params.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
|
||||
|
||||
# Load state dict with `x_embedder`.
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
|
||||
|
||||
pipe.set_adapters(["one", "two"])
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
|
||||
"Different LoRAs should lead to different results.",
|
||||
)
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
@@ -558,6 +658,131 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
|
||||
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_lora_unload_with_parameter_expanded_shapes(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
num_channels_without_control = 4
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
self.assertTrue(
|
||||
transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
|
||||
)
|
||||
|
||||
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
|
||||
components["transformer"] = transformer
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
control_image = inputs.pop("control_image")
|
||||
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
control_pipe = self.pipeline_class(**components)
|
||||
out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
|
||||
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
|
||||
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
inputs["control_image"] = control_image
|
||||
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
|
||||
self.assertTrue(
|
||||
control_pipe.transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
|
||||
)
|
||||
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
|
||||
self.assertTrue(
|
||||
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
|
||||
)
|
||||
inputs.pop("control_image")
|
||||
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features)
|
||||
|
||||
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
num_channels_without_control = 4
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
self.assertTrue(
|
||||
transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
|
||||
)
|
||||
|
||||
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
|
||||
components["transformer"] = transformer
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
control_image = inputs.pop("control_image")
|
||||
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
control_pipe = self.pipeline_class(**components)
|
||||
out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
|
||||
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
|
||||
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
inputs["control_image"] = control_image
|
||||
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
|
||||
self.assertTrue(
|
||||
control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
|
||||
)
|
||||
no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@@ -12,9 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
@@ -26,7 +29,11 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
)
|
||||
|
||||
@@ -182,3 +189,69 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
@require_peft_backend
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
"""internal note: The integration slices were obtained on DGX.
|
||||
|
||||
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
|
||||
assertions to pass.
|
||||
"""
|
||||
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
self.pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_original_format_cseti(self):
|
||||
self.pipeline.load_lora_weights(
|
||||
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
|
||||
)
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline.vae.enable_tiling()
|
||||
|
||||
prompt = "CSETIARCANE. A cat walks on the grass, realistic"
|
||||
|
||||
out = self.pipeline(
|
||||
prompt=prompt,
|
||||
height=320,
|
||||
width=512,
|
||||
num_frames=9,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).frames[0]
|
||||
out = out.flatten()
|
||||
out_slice = np.concatenate((out[:8], out[-8:]))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
|
||||
# fmt: on
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -20,6 +20,7 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
|
||||
from diffusers.utils import is_accelerate_version, logging
|
||||
@@ -568,6 +569,27 @@ class SlowBnb4BitFluxTests(Base4bitTests):
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
|
||||
self.assertTrue(max_diff < 1e-3)
|
||||
|
||||
def test_lora_loading(self):
|
||||
self.pipeline_4bit.load_lora_weights(
|
||||
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
|
||||
)
|
||||
self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)
|
||||
|
||||
output = self.pipeline_4bit(
|
||||
prompt=self.prompt,
|
||||
height=256,
|
||||
width=256,
|
||||
max_sequence_length=64,
|
||||
output_type="np",
|
||||
num_inference_steps=8,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images
|
||||
out_slice = output[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
|
||||
self.assertTrue(max_diff < 1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
class BaseBnb4BitSerializationTests(Base4bitTests):
|
||||
|
||||
@@ -131,8 +131,9 @@ class TorchAoTest(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_components(self, quantization_config: TorchAoConfig):
|
||||
model_id = "hf-internal-testing/tiny-flux-pipe"
|
||||
def get_dummy_components(
|
||||
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
|
||||
):
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
@@ -211,8 +212,8 @@ class TorchAoTest(unittest.TestCase):
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str):
|
||||
components = self.get_dummy_components(quantization_config, model_id)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device)
|
||||
|
||||
@@ -223,44 +224,45 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_quantization(self):
|
||||
# fmt: off
|
||||
QUANTIZATION_TYPES_TO_TEST = [
|
||||
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
|
||||
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
|
||||
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
]
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
# fmt: off
|
||||
QUANTIZATION_TYPES_TO_TEST = [
|
||||
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
|
||||
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
|
||||
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
]
|
||||
|
||||
if TorchAoConfig._is_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
# =====
|
||||
# The following lead to an internal torch error:
|
||||
# RuntimeError: mat2 shape (32x4 must be divisible by 16
|
||||
# Skip these for now; TODO(aryan): investigate later
|
||||
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
# Cutlass fails to initialize for below
|
||||
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
])
|
||||
# fmt: on
|
||||
if TorchAoConfig._is_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
# =====
|
||||
# The following lead to an internal torch error:
|
||||
# RuntimeError: mat2 shape (32x4 must be divisible by 16
|
||||
# Skip these for now; TODO(aryan): investigate later
|
||||
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
# Cutlass fails to initialize for below
|
||||
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quant_kwargs = {}
|
||||
if quantization_name in ["uint4wo", "uint7wo"]:
|
||||
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
|
||||
quant_kwargs.update({"group_size": 16})
|
||||
quantization_config = TorchAoConfig(
|
||||
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
|
||||
)
|
||||
self._test_quant_type(quantization_config, expected_slice)
|
||||
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quant_kwargs = {}
|
||||
if quantization_name in ["uint4wo", "uint7wo"]:
|
||||
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
|
||||
quant_kwargs.update({"group_size": 16})
|
||||
quantization_config = TorchAoConfig(
|
||||
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
|
||||
)
|
||||
self._test_quant_type(quantization_config, expected_slice, model_id)
|
||||
|
||||
def test_int4wo_quant_bfloat16_conversion(self):
|
||||
"""
|
||||
@@ -280,12 +282,14 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertEqual(weight.quant_max, 15)
|
||||
|
||||
def test_device_map(self):
|
||||
# Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did
|
||||
# it would have errored out. Now, we do. So, device_map basically never worked with or without
|
||||
# sharded checkpoints. This will need to be supported in the future (TODO(aryan))
|
||||
"""
|
||||
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
|
||||
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
|
||||
correctly set (in the `hf_device_map` attribute of the model).
|
||||
"""
|
||||
|
||||
custom_device_map_dict = {
|
||||
"time_text_embed": torch_device,
|
||||
"context_embedder": torch_device,
|
||||
@@ -297,48 +301,54 @@ class TorchAoTest(unittest.TestCase):
|
||||
}
|
||||
device_maps = ["auto", custom_device_map_dict]
|
||||
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
|
||||
# inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
# expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
|
||||
|
||||
for device_map in device_maps:
|
||||
device_map_to_compare = {"": 0} if device_map == "auto" else device_map
|
||||
# device_map_to_compare = {"": 0} if device_map == "auto" else device_map
|
||||
|
||||
# Test non-sharded model
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
# Test non-sharded model - should work
|
||||
with self.assertRaises(NotImplementedError):
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
_ = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
|
||||
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
|
||||
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
# output = quantized_model(**inputs)[0]
|
||||
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# Test sharded model
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-sharded",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
# Test sharded model - should not work
|
||||
with self.assertRaises(NotImplementedError):
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
_ = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-sharded",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16,
|
||||
offload_folder=offload_folder,
|
||||
)
|
||||
|
||||
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
# weight = quantized_model.transformer_blocks[0].ff.net[2].weight
|
||||
# self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
|
||||
# self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
# output = quantized_model(**inputs)[0]
|
||||
# output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
# self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_modules_to_not_convert(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
|
||||
@@ -404,43 +414,63 @@ class TorchAoTest(unittest.TestCase):
|
||||
@nightly
|
||||
def test_torch_compile(self):
|
||||
r"""Test that verifies if torch.compile works with torchao quantization."""
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device, dtype=torch.bfloat16)
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
components = self.get_dummy_components(quantization_config, model_id=model_id)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
normal_output = pipe(**inputs)[0].flatten()[-32:]
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
normal_output = pipe(**inputs)[0].flatten()[-32:]
|
||||
|
||||
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
compile_output = pipe(**inputs)[0].flatten()[-32:]
|
||||
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
compile_output = pipe(**inputs)[0].flatten()[-32:]
|
||||
|
||||
# Note: Seems to require higher tolerance
|
||||
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
|
||||
# Note: Seems to require higher tolerance
|
||||
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
|
||||
|
||||
def test_memory_footprint(self):
|
||||
r"""
|
||||
A simple test to check if the model conversion has been done correctly by checking on the
|
||||
memory footprint of the converted model and the class type of the linear layers of the converted models
|
||||
"""
|
||||
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
|
||||
transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
|
||||
transformer_bf16 = self.get_dummy_components(None)["transformer"]
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
|
||||
transformer_int4wo_gs32 = self.get_dummy_components(
|
||||
TorchAoConfig("int4wo", group_size=32), model_id=model_id
|
||||
)["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
|
||||
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
|
||||
|
||||
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
|
||||
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
|
||||
total_int8wo = get_model_size_in_bytes(transformer_int8wo)
|
||||
total_bf16 = get_model_size_in_bytes(transformer_bf16)
|
||||
# Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64
|
||||
for block in transformer_int4wo.transformer_blocks:
|
||||
self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
|
||||
self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
|
||||
|
||||
# Latter has smaller group size, so more groups -> more scales and zero points
|
||||
self.assertTrue(total_int4wo < total_int4wo_gs32)
|
||||
# int8 quantizes more layers compare to int4 with default group size
|
||||
self.assertTrue(total_int8wo < total_int4wo)
|
||||
# int4wo does not quantize too many layers because of default group size, but for the layers it does
|
||||
# there is additional overhead of scales and zero points
|
||||
self.assertTrue(total_bf16 < total_int4wo)
|
||||
# Will quantize all the linear layers except x_embedder
|
||||
for name, module in transformer_int4wo_gs32.named_modules():
|
||||
if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
|
||||
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
|
||||
|
||||
# Will quantize all the linear layers
|
||||
for module in transformer_int8wo.modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
|
||||
|
||||
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
|
||||
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
|
||||
total_int8wo = get_model_size_in_bytes(transformer_int8wo)
|
||||
total_bf16 = get_model_size_in_bytes(transformer_bf16)
|
||||
|
||||
# TODO: refactor to align with other quantization tests
|
||||
# Latter has smaller group size, so more groups -> more scales and zero points
|
||||
self.assertTrue(total_int4wo < total_int4wo_gs32)
|
||||
# int8 quantizes more layers compare to int4 with default group size
|
||||
self.assertTrue(total_int8wo < total_int4wo)
|
||||
# int4wo does not quantize too many layers because of default group size, but for the layers it does
|
||||
# there is additional overhead of scales and zero points
|
||||
self.assertTrue(total_bf16 < total_int4wo)
|
||||
|
||||
def test_wrong_config(self):
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -500,6 +530,8 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
|
||||
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
|
||||
@@ -508,8 +540,8 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False
|
||||
)
|
||||
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
|
||||
).to(device=torch_device)
|
||||
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
output = loaded_quantized_model(**inputs)[0]
|
||||
@@ -563,20 +595,25 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_components(self, quantization_config: TorchAoConfig):
|
||||
# This is just for convenience, so that we can modify it at one place for custom environments and locally testing
|
||||
cache_dir = None
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
||||
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
|
||||
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
|
||||
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
@@ -611,10 +648,12 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
|
||||
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0].flatten()
|
||||
output_slice = np.concatenate((output[:16], output[-16:]))
|
||||
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_quantization(self):
|
||||
@@ -627,7 +666,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
if TorchAoConfig._is_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
|
||||
("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])),
|
||||
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
@@ -637,3 +676,125 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def test_serialization_int8wo(self):
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
weight = pipe.transformer.x_embedder.weight
|
||||
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0].flatten()[:128]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
pipe.remove_all_hooks()
|
||||
del pipe.transformer
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
|
||||
)
|
||||
pipe.transformer = transformer
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
weight = transformer.x_embedder.weight
|
||||
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
|
||||
loaded_output = pipe(**inputs)[0].flatten()[:128]
|
||||
# Seems to require higher tolerance depending on which machine it is being run.
|
||||
# A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of
|
||||
# 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04,
|
||||
# on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here.
|
||||
self.assertTrue(np.allclose(output, loaded_output, atol=0.06))
|
||||
|
||||
def test_memory_footprint_int4wo(self):
|
||||
# The original checkpoints are in bf16 and about 24 GB
|
||||
expected_memory_in_gb = 6.0
|
||||
quantization_config = TorchAoConfig("int4wo")
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
|
||||
self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
|
||||
|
||||
def test_memory_footprint_int8wo(self):
|
||||
# The original checkpoints are in bf16 and about 24 GB
|
||||
expected_memory_in_gb = 12.0
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
|
||||
self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_dummy_inputs(self, device: torch.device, seed: int = 0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator().manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "an astronaut riding a horse in space",
|
||||
"height": 512,
|
||||
"width": 512,
|
||||
"num_inference_steps": 20,
|
||||
"output_type": "np",
|
||||
"generator": generator,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_transformer_int8wo(self):
|
||||
# fmt: off
|
||||
expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703])
|
||||
# fmt: on
|
||||
|
||||
# This is just for convenience, so that we can modify it at one place for custom environments and locally testing
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
use_safetensors=False,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Verify that all linear layer weights are quantized
|
||||
for name, module in pipe.transformer.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
|
||||
|
||||
# Verify outputs match expected slice
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0].flatten()
|
||||
output_slice = np.concatenate((output[:16], output[-16:]))
|
||||
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
FluxTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||
|
||||
repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert (
|
||||
model.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
torch.cuda.empty_cache()
|
||||
model = self.model_class.from_single_file(ckpt_path)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
Reference in New Issue
Block a user