Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b3ab9c3f24 | |||
| a0119978f8 | |||
| a4f7080bb4 |
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
|
||||
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
|
||||
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
|
||||
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
|
||||
| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports Controlnet with Flux Fill model.| [Flux Fill ControlNet Pipeline](#Flux-Fill-ControlNet-Pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) |
|
||||
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -5527,106 +5527,3 @@ images = pipe(
|
||||
).images
|
||||
images[0].save("pizzeria.png")
|
||||
```
|
||||
|
||||
# Flux Fill ControlNet Pipeline
|
||||
|
||||
This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.
|
||||
|
||||
While FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.
|
||||
|
||||
## Example Usage
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import (
|
||||
FluxControlNetModel,
|
||||
FluxPriorReduxPipeline,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# NEW PIPELINE (updated name)
|
||||
from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Models
|
||||
base_model = "black-forest-labs/FLUX.1-Fill-dev"
|
||||
controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0"
|
||||
prior_model = "black-forest-labs/FLUX.1-Redux-dev"
|
||||
|
||||
# Load ControlNet
|
||||
controlnet = FluxControlNetModel.from_pretrained(
|
||||
controlnet_model,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
|
||||
# Load Fill + ControlNet Pipeline
|
||||
fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
|
||||
base_model,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
# OPTIONAL FP8
|
||||
# fill_pipe.transformer.enable_layerwise_casting(
|
||||
# storage_dtype=torch.float8_e4m3fn,
|
||||
# compute_dtype=torch.bfloat16
|
||||
# )
|
||||
|
||||
# OPTIONAL Prior Redux
|
||||
#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
|
||||
# prior_model,
|
||||
# torch_dtype=dtype,
|
||||
#).to(device)
|
||||
|
||||
# Inputs
|
||||
|
||||
# combined_image = load_image("person_input.png")
|
||||
|
||||
|
||||
# 1. Prior conditioning
|
||||
#prior_out = pipe_prior_redux(
|
||||
# image=cloth_image,
|
||||
# prompt=cloth_prompt,
|
||||
#)
|
||||
|
||||
# 2. Fill Inpaint with ControlNet
|
||||
|
||||
# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
|
||||
|
||||
img = load_image(r"imgs/background.jpg")
|
||||
mask = load_image(r"imgs/mask.png")
|
||||
|
||||
control_image_depth = load_image(r"imgs/dog_depth _2.png")
|
||||
|
||||
result = fill_pipe(
|
||||
prompt="a dog on a bench",
|
||||
image=img,
|
||||
mask_image=mask,
|
||||
|
||||
control_image=control_image_depth,
|
||||
control_mode=[2], # union mode
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=0.8,
|
||||
controlnet_conditioning_scale=0.9,
|
||||
|
||||
height=1024,
|
||||
width=1024,
|
||||
|
||||
strength=1.0,
|
||||
guidance_scale=50.0,
|
||||
num_inference_steps=60,
|
||||
max_sequence_length=512,
|
||||
|
||||
# **prior_out,
|
||||
)
|
||||
|
||||
# result.images[0].save("flux_fill_controlnet_inpaint.png")
|
||||
|
||||
from datetime import datetime
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
result.images[0].save(f"flux_fill_controlnet_inpaint_depth{timestamp}.jpg")
|
||||
```
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -79,14 +79,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to `1000`):
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to `0.0001`):
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to `0.02`):
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
skip_prk_steps (`bool`, defaults to `False`):
|
||||
@@ -96,13 +97,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
||||
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the alpha value at step 0.
|
||||
prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
|
||||
or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
|
||||
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
|
||||
or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
|
||||
paper).
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, defaults to `0`):
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps, as required by some model families.
|
||||
"""
|
||||
|
||||
@@ -115,12 +117,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
skip_prk_steps: bool = False,
|
||||
set_alpha_to_one: bool = False,
|
||||
prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "leading",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
@@ -162,7 +164,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.plms_timesteps = None
|
||||
self.timesteps = None
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -241,7 +243,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
||||
|
||||
Returns:
|
||||
@@ -274,13 +276,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
@@ -332,13 +335,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
@@ -399,27 +403,19 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def _get_prev_sample(
|
||||
self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM
|
||||
paper](https://huggingface.co/papers/2202.09778).
|
||||
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
|
||||
# See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
|
||||
# this function computes x_(t−δ) using the formula of (9)
|
||||
# Note that x_t needs to be added to both sides of the equation
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The current sample x_t.
|
||||
timestep (`int`):
|
||||
The current timestep t.
|
||||
prev_timestep (`int`):
|
||||
The previous timestep (t-δ).
|
||||
model_output (`torch.Tensor`):
|
||||
The model output e_θ(x_t, t).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The previous sample x_(t-δ).
|
||||
"""
|
||||
# Notation (<variable name> -> <name in paper>
|
||||
# alpha_prod_t -> α_t
|
||||
# alpha_prod_t_prev -> α_(t−δ)
|
||||
# beta_prod_t -> (1 - α_t)
|
||||
# beta_prod_t_prev -> (1 - α_(t−δ))
|
||||
# sample -> x_t
|
||||
# model_output -> e_θ(x_t, t)
|
||||
# prev_sample -> x_(t−δ)
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
@@ -493,5 +489,5 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -13,12 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers import (
|
||||
AuraFlowPipeline,
|
||||
AuraFlowTransformer2DModel,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
|
||||
from ..testing_utils import (
|
||||
floats_tensor,
|
||||
@@ -36,7 +40,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class TestAuraFlowLoRA(PeftLoraLoaderMixinTests):
|
||||
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = AuraFlowPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -99,34 +103,34 @@ class TestAuraFlowLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.skip("Not supported in AuraFlow.")
|
||||
@unittest.skip("Not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in AuraFlow.")
|
||||
@unittest.skip("Not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in AuraFlow.")
|
||||
@unittest.skip("Not supported in AuraFlow.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -13,9 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
@@ -38,7 +39,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class TestCogVideoXLoRA(PeftLoraLoaderMixinTests):
|
||||
class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = CogVideoXPipeline
|
||||
scheduler_cls = CogVideoXDPMScheduler
|
||||
scheduler_kwargs = {"timestep_spacing": "trailing"}
|
||||
@@ -118,59 +119,54 @@ class TestCogVideoXLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3, pipe=pipe)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
|
||||
super().test_lora_scale_kwargs_match_fusion(
|
||||
base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3
|
||||
)
|
||||
def test_lora_scale_kwargs_match_fusion(self):
|
||||
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"offload_type, use_stream",
|
||||
[("block_level", True), ("leaf_level", False)],
|
||||
)
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False)])
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
|
||||
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -13,9 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
@@ -25,6 +28,7 @@ from ..testing_utils import (
|
||||
require_peft_backend,
|
||||
require_torch_accelerator,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@@ -43,7 +47,7 @@ class TokenizerWrapper:
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class TestCogView4LoRA(PeftLoraLoaderMixinTests):
|
||||
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = CogView4Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -109,50 +113,72 @@ class TestCogView4LoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"offload_type, use_stream",
|
||||
[("block_level", True), ("leaf_level", False)],
|
||||
)
|
||||
def test_simple_inference_save_pretrained(self):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
components, _, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
|
||||
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False)])
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
|
||||
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
|
||||
@pytest.mark.skip("Not supported in CogView4.")
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in CogView4.")
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in CogView4.")
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
+344
-244
@@ -16,11 +16,13 @@ import copy
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -44,12 +46,14 @@ from ..testing_utils import (
|
||||
|
||||
if is_peft_available():
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
sys.path.append(".")
|
||||
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set
|
||||
|
||||
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = FluxPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -111,134 +115,165 @@ class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_with_alpha_in_state_dict(self, tmpdirname, pipe):
|
||||
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
def test_with_alpha_in_state_dict(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
# modify the state dict to have alpha values following
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
|
||||
state_dict_with_alpha = safetensors.torch.load_file(
|
||||
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
|
||||
)
|
||||
alpha_dict = {}
|
||||
for k, v in state_dict_with_alpha.items():
|
||||
if "transformer" in k and "to_k" in k and ("lora_A" in k):
|
||||
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
|
||||
state_dict_with_alpha.update(alpha_dict)
|
||||
# modify the state dict to have alpha values following
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
|
||||
state_dict_with_alpha = safetensors.torch.load_file(
|
||||
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
|
||||
)
|
||||
alpha_dict = {}
|
||||
for k, v in state_dict_with_alpha.items():
|
||||
# only do for `transformer` and for the k projections -- should be enough to test.
|
||||
if "transformer" in k and "to_k" in k and "lora_A" in k:
|
||||
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
|
||||
state_dict_with_alpha.update(alpha_dict)
|
||||
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(state_dict_with_alpha)
|
||||
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints should give same results."
|
||||
)
|
||||
assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)
|
||||
|
||||
def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname, pipe):
|
||||
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_lora_expansion_works_for_absent_keys(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = self.get_base_pipe_output()
|
||||
|
||||
# Modify the config to have a layer which won't be present in the second LoRA we will load.
|
||||
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
|
||||
modified_denoiser_lora_config.target_modules.add("x_embedder")
|
||||
|
||||
pipe.transformer.add_adapter(modified_denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
|
||||
"LoRA should lead to different results."
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
|
||||
|
||||
# Modify the state dict to exclude "x_embedder" related LoRA params.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
|
||||
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
|
||||
pipe.set_adapters(["one", "two"])
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
assert not np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
|
||||
"Different LoRAs should lead to different results."
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
|
||||
"Different LoRAs should lead to different results.",
|
||||
)
|
||||
assert not np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
|
||||
"LoRA should lead to different results."
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
|
||||
def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname, pipe):
|
||||
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
def test_lora_expansion_works_for_extra_keys(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
output_no_lora = self.get_base_pipe_output()
|
||||
|
||||
# Modify the config to have a layer which won't be present in the first LoRA we will load.
|
||||
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
|
||||
modified_denoiser_lora_config.target_modules.add("x_embedder")
|
||||
|
||||
pipe.transformer.add_adapter(modified_denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
|
||||
"LoRA should lead to different results."
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
# Modify the state dict to exclude "x_embedder" related LoRA params.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
|
||||
|
||||
# Load state dict with `x_embedder`.
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
|
||||
|
||||
pipe.set_adapters(["one", "two"])
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
assert not np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
|
||||
"Different LoRAs should lead to different results."
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
|
||||
"Different LoRAs should lead to different results.",
|
||||
)
|
||||
assert not np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
|
||||
"LoRA should lead to different results."
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
)
|
||||
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = FluxControlPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -303,7 +338,12 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_with_norm_in_state_dict(self, pipe):
|
||||
def test_with_norm_in_state_dict(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
@@ -324,32 +364,39 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
pipe.load_lora_weights(norm_state_dict)
|
||||
lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
assert (
|
||||
self.assertTrue(
|
||||
"The provided state dict contains normalization layers in addition to LoRA layers"
|
||||
in cap_logger.out
|
||||
)
|
||||
assert len(pipe.transformer._transformer_norm_layers) > 0
|
||||
self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
assert pipe.transformer._transformer_norm_layers is None
|
||||
assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05)
|
||||
assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), (
|
||||
f"{norm_layer} is tested"
|
||||
self.assertTrue(pipe.transformer._transformer_norm_layers is None)
|
||||
self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5))
|
||||
self.assertFalse(
|
||||
np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested"
|
||||
)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
for key in list(norm_state_dict.keys()):
|
||||
norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
|
||||
pipe.load_lora_weights(norm_state_dict)
|
||||
assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
|
||||
|
||||
def test_lora_parameter_expanded_shapes(self, pipe):
|
||||
self.assertTrue(
|
||||
"Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
|
||||
)
|
||||
|
||||
def test_lora_parameter_expanded_shapes(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
@@ -358,21 +405,24 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
assert transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
|
||||
self.assertTrue(
|
||||
transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
|
||||
)
|
||||
|
||||
original_transformer_state_dict = pipe.transformer.state_dict()
|
||||
x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
|
||||
incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
|
||||
assert "x_embedder.weight" in incompatible_keys.missing_keys, (
|
||||
"Could not find x_embedder.weight in the missing keys."
|
||||
self.assertTrue(
|
||||
"x_embedder.weight" in incompatible_keys.missing_keys,
|
||||
"Could not find x_embedder.weight in the missing keys.",
|
||||
)
|
||||
|
||||
transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
|
||||
pipe.transformer = transformer
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
|
||||
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
|
||||
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
@@ -381,13 +431,15 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
# Testing opposite direction where the LoRA params are zero-padded.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
@@ -402,13 +454,15 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
|
||||
def test_normal_lora_with_expanded_lora_raises_error(self):
|
||||
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
|
||||
@@ -440,28 +494,32 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
assert pipe.get_active_adapters() == ["adapter-1"]
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
|
||||
(_, _, inputs) = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
|
||||
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-2")
|
||||
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
|
||||
assert pipe.get_active_adapters() == ["adapter-2"]
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
|
||||
|
||||
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
|
||||
# This should raise a runtime error on input shapes being incompatible.
|
||||
@@ -482,24 +540,32 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
|
||||
assert pipe.transformer.config.in_channels == in_features
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features)
|
||||
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
|
||||
}
|
||||
|
||||
# We should check for input shapes being incompatible here. But because above mentioned issue is
|
||||
# not a supported use case, and because of the PEFT renaming, we will currently have a shape
|
||||
# mismatch error.
|
||||
with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"):
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-2")
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"size mismatch for x_embedder.lora_A.adapter-2.weight",
|
||||
pipe.load_lora_weights,
|
||||
lora_state_dict,
|
||||
"adapter-2",
|
||||
)
|
||||
|
||||
def test_fuse_expanded_lora_with_regular_lora(self):
|
||||
# This test checks if it works when a lora with expanded shapes (like control loras) but
|
||||
@@ -531,7 +597,7 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
|
||||
}
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
@@ -544,44 +610,54 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
}
|
||||
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-2")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
|
||||
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
|
||||
assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001)
|
||||
assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001)
|
||||
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
|
||||
|
||||
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
|
||||
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001)
|
||||
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_load_regular_lora(self, base_pipe_output, pipe):
|
||||
def test_load_regular_lora(self):
|
||||
# This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
|
||||
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
|
||||
# transformers include Flux Fill, Flux Control, etc.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
in_features = in_features // 2
|
||||
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
|
||||
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
|
||||
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
|
||||
assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001)
|
||||
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
|
||||
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_lora_unload_with_parameter_expanded_shapes(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
@@ -594,8 +670,9 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
assert transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
|
||||
self.assertTrue(
|
||||
transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
|
||||
)
|
||||
|
||||
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
|
||||
@@ -620,31 +697,33 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
inputs["control_image"] = control_image
|
||||
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
|
||||
assert control_pipe.transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
|
||||
self.assertTrue(
|
||||
control_pipe.transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
|
||||
)
|
||||
|
||||
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
|
||||
assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}"
|
||||
self.assertTrue(
|
||||
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
|
||||
)
|
||||
|
||||
inputs.pop("control_image")
|
||||
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert np.allclose(unloaded_lora_out, original_out, atol=0.0001, rtol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
|
||||
assert pipe.transformer.config.in_channels == in_features
|
||||
|
||||
self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features)
|
||||
|
||||
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
@@ -652,12 +731,14 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
num_channels_without_control = 4
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
assert transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
|
||||
self.assertTrue(
|
||||
transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
|
||||
)
|
||||
|
||||
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
|
||||
@@ -682,38 +763,40 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
inputs["control_image"] = control_image
|
||||
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
|
||||
assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
|
||||
self.assertTrue(
|
||||
control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
|
||||
)
|
||||
|
||||
no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
|
||||
assert pipe.transformer.config.in_channels == in_features * 2
|
||||
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -723,7 +806,7 @@ class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_accelerator
|
||||
class TestFluxLoRAIntegration:
|
||||
class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
"""internal note: The integration slices were obtained on audace.
|
||||
|
||||
torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
|
||||
@@ -733,27 +816,33 @@ class TestFluxLoRAIntegration:
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pipeline(self):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
try:
|
||||
yield pipe
|
||||
finally:
|
||||
del pipe
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_flux_the_last_ben(self, pipeline):
|
||||
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
del self.pipeline
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_flux_the_last_ben(self):
|
||||
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
# Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI
|
||||
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
|
||||
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
|
||||
prompt = "jon snow eating pizza with ketchup"
|
||||
out = pipeline(
|
||||
|
||||
out = self.pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=4.0,
|
||||
@@ -762,57 +851,71 @@ class TestFluxLoRAIntegration:
|
||||
).images
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
|
||||
def test_flux_kohya(self, pipeline):
|
||||
pipeline.load_lora_weights("Norod78/brain-slug-flux")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya(self):
|
||||
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
|
||||
prompt = "The cat with a brain slug earring"
|
||||
out = pipeline(
|
||||
out = self.pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=4.5,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).images
|
||||
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
|
||||
def test_flux_kohya_with_text_encoder(self, pipeline):
|
||||
pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya_with_text_encoder(self):
|
||||
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
|
||||
prompt = "optimus is cleaning the house with broomstick"
|
||||
out = pipeline(
|
||||
out = self.pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=4.5,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).images
|
||||
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
|
||||
def test_flux_kohya_embedders_conversion(self, pipeline):
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya_embedders_conversion(self):
|
||||
"""Test that embedders load without throwing errors"""
|
||||
pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
|
||||
pipeline.unload_lora_weights()
|
||||
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
|
||||
self.pipeline.unload_lora_weights()
|
||||
|
||||
assert True
|
||||
|
||||
def test_flux_xlabs(self):
|
||||
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
|
||||
def test_flux_xlabs(self, pipeline):
|
||||
pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
|
||||
out = pipeline(
|
||||
|
||||
out = self.pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=3.5,
|
||||
@@ -820,17 +923,23 @@ class TestFluxLoRAIntegration:
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).images
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.3965, 0.418, 0.4434, 0.4082, 0.4375, 0.459, 0.4141, 0.4375, 0.498])
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_xlabs_load_lora_with_single_blocks(self):
|
||||
self.pipeline.load_lora_weights(
|
||||
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
|
||||
)
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline.enable_model_cpu_offload()
|
||||
|
||||
def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline):
|
||||
pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline.enable_model_cpu_offload()
|
||||
prompt = "a wizard mouse playing chess"
|
||||
out = pipeline(
|
||||
|
||||
out = self.pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=3.5,
|
||||
@@ -842,43 +951,40 @@ class TestFluxLoRAIntegration:
|
||||
[0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
|
||||
)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_accelerator
|
||||
class TestFluxControlLoRAIntegration:
|
||||
class FluxControlLoRAIntegrationTests(unittest.TestCase):
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds."
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pipeline(self):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
try:
|
||||
yield pipe
|
||||
finally:
|
||||
del pipe
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lora_ckpt_id",
|
||||
[
|
||||
"black-forest-labs/FLUX.1-Canny-dev-lora",
|
||||
"black-forest-labs/FLUX.1-Depth-dev-lora",
|
||||
],
|
||||
)
|
||||
def test_lora(self, pipeline, lora_ckpt_id):
|
||||
pipeline.load_lora_weights(lora_ckpt_id)
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
self.pipeline = FluxControlPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
|
||||
def test_lora(self, lora_ckpt_id):
|
||||
self.pipeline.load_lora_weights(lora_ckpt_id)
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
|
||||
if "Canny" in lora_ckpt_id:
|
||||
control_image = load_image(
|
||||
@@ -889,7 +995,7 @@ class TestFluxControlLoRAIntegration:
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
|
||||
)
|
||||
|
||||
image = pipeline(
|
||||
image = self.pipeline(
|
||||
prompt=self.prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
@@ -910,18 +1016,12 @@ class TestFluxControlLoRAIntegration:
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lora_ckpt_id",
|
||||
[
|
||||
"black-forest-labs/FLUX.1-Canny-dev-lora",
|
||||
"black-forest-labs/FLUX.1-Depth-dev-lora",
|
||||
],
|
||||
)
|
||||
def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
|
||||
pipeline.load_lora_weights(lora_ckpt_id)
|
||||
pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
|
||||
def test_lora_with_turbo(self, lora_ckpt_id):
|
||||
self.pipeline.load_lora_weights(lora_ckpt_id)
|
||||
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
|
||||
if "Canny" in lora_ckpt_id:
|
||||
control_image = load_image(
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
@@ -48,7 +48,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
|
||||
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -149,41 +149,46 @@ class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@pytest.mark.skip("Not supported in HunyuanVideo.")
|
||||
# TODO(aryan): Fix the following test
|
||||
@unittest.skip("This test fails with an error I haven't been able to debug yet.")
|
||||
def test_simple_inference_save_pretrained(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in HunyuanVideo.")
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in HunyuanVideo.")
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -192,7 +197,7 @@ class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_accelerator
|
||||
class TestHunyuanVideoLoRAIntegration:
|
||||
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
"""internal note: The integration slices were obtained on DGX.
|
||||
|
||||
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
|
||||
@@ -202,8 +207,9 @@ class TestHunyuanVideoLoRAIntegration:
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pipeline(self):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -211,27 +217,27 @@ class TestHunyuanVideoLoRAIntegration:
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to(
|
||||
torch_device
|
||||
)
|
||||
try:
|
||||
yield pipe
|
||||
finally:
|
||||
del pipe
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
self.pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
def test_original_format_cseti(self, pipeline):
|
||||
pipeline.load_lora_weights(
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_original_format_cseti(self):
|
||||
self.pipeline.load_lora_weights(
|
||||
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
|
||||
)
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline.vae.enable_tiling()
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline.vae.enable_tiling()
|
||||
|
||||
prompt = "CSETIARCANE. A cat walks on the grass, realistic"
|
||||
|
||||
out = pipeline(
|
||||
out = self.pipeline(
|
||||
prompt=prompt,
|
||||
height=320,
|
||||
width=512,
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class TestLTXVideoLoRA(PeftLoraLoaderMixinTests):
|
||||
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = LTXPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -108,40 +108,40 @@ class TestLTXVideoLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@pytest.mark.skip("Not supported in LTXVideo.")
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in LTXVideo.")
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in LTXVideo.")
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -35,7 +36,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class TestLumina2LoRA(PeftLoraLoaderMixinTests):
|
||||
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = Lumina2Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -100,35 +101,35 @@ class TestLumina2LoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.skip("Not supported in Lumina2.")
|
||||
@unittest.skip("Not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Lumina2.")
|
||||
@unittest.skip("Not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Lumina2.")
|
||||
@unittest.skip("Not supported in Lumina2.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -138,17 +139,20 @@ class TestLumina2LoRA(PeftLoraLoaderMixinTests):
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=False,
|
||||
)
|
||||
def test_lora_fuse_nan(self, pipe):
|
||||
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
def test_lora_fuse_nan(self):
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
@@ -162,4 +166,4 @@ class TestLumina2LoRA(PeftLoraLoaderMixinTests):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
out = pipe(**inputs)[0]
|
||||
|
||||
assert np.isnan(out).all()
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class TestMochiLoRA(PeftLoraLoaderMixinTests):
|
||||
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = MochiPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -99,44 +99,44 @@ class TestMochiLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@pytest.mark.skip("Not supported in Mochi.")
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Mochi.")
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Mochi.")
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class TestQwenImageLoRA(PeftLoraLoaderMixinTests):
|
||||
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = QwenImagePipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -96,34 +96,34 @@ class TestQwenImageLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.skip("Not supported in Qwen Image.")
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Qwen Image.")
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Qwen Image.")
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import Gemma2Model, GemmaTokenizer
|
||||
|
||||
@@ -29,7 +29,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class TestSanaLoRA(PeftLoraLoaderMixinTests):
|
||||
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = SanaPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {"shift": 7.0}
|
||||
@@ -105,38 +105,38 @@ class TestSanaLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@pytest.mark.skip("Not supported in SANA.")
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in SANA.")
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in SANA.")
|
||||
@unittest.skip("Not supported in SANA.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in SANA.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_inference_denoiser(self):
|
||||
return super().test_layerwise_casting_inference_denoiser()
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
# limitations under the License.
|
||||
import gc
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -55,7 +55,7 @@ if is_accelerate_available():
|
||||
from accelerate.utils import release_memory
|
||||
|
||||
|
||||
class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
|
||||
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
scheduler_cls = DDIMScheduler
|
||||
scheduler_kwargs = {
|
||||
@@ -91,6 +91,16 @@ class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
|
||||
def output_shape(self):
|
||||
return (1, 64, 64, 3)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Keeping this test here makes sense because it doesn't look any integration
|
||||
# (value assertions on logits).
|
||||
@slow
|
||||
@@ -104,8 +114,15 @@ class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
|
||||
pipe.load_lora_weights(lora_id, adapter_name="adapter-2")
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder),
|
||||
"Lora not correctly set in text encoder",
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.unet),
|
||||
"Lora not correctly set in unet",
|
||||
)
|
||||
|
||||
# We will offload the first adapter in CPU and check if the offloading
|
||||
# has been performed correctly
|
||||
@@ -113,35 +130,35 @@ class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
for name, module in pipe.unet.named_modules():
|
||||
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
|
||||
assert module.weight.device == torch.device("cpu")
|
||||
self.assertTrue(module.weight.device == torch.device("cpu"))
|
||||
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
|
||||
assert module.weight.device != torch.device("cpu")
|
||||
self.assertTrue(module.weight.device != torch.device("cpu"))
|
||||
|
||||
for name, module in pipe.text_encoder.named_modules():
|
||||
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
|
||||
assert module.weight.device == torch.device("cpu")
|
||||
self.assertTrue(module.weight.device == torch.device("cpu"))
|
||||
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
|
||||
assert module.weight.device != torch.device("cpu")
|
||||
self.assertTrue(module.weight.device != torch.device("cpu"))
|
||||
|
||||
pipe.set_lora_device(["adapter-1"], 0)
|
||||
|
||||
for n, m in pipe.unet.named_modules():
|
||||
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
|
||||
assert m.weight.device != torch.device("cpu")
|
||||
self.assertTrue(m.weight.device != torch.device("cpu"))
|
||||
|
||||
for n, m in pipe.text_encoder.named_modules():
|
||||
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
|
||||
assert m.weight.device != torch.device("cpu")
|
||||
self.assertTrue(m.weight.device != torch.device("cpu"))
|
||||
|
||||
pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device)
|
||||
|
||||
for n, m in pipe.unet.named_modules():
|
||||
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
|
||||
assert m.weight.device != torch.device("cpu")
|
||||
self.assertTrue(m.weight.device != torch.device("cpu"))
|
||||
|
||||
for n, m in pipe.text_encoder.named_modules():
|
||||
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
|
||||
assert m.weight.device != torch.device("cpu")
|
||||
self.assertTrue(m.weight.device != torch.device("cpu"))
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
@@ -164,9 +181,15 @@ class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
|
||||
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder),
|
||||
"Lora not correctly set in text encoder",
|
||||
)
|
||||
|
||||
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.unet),
|
||||
"Lora not correctly set in unet",
|
||||
)
|
||||
|
||||
for name, param in pipe.unet.named_parameters():
|
||||
if "lora_" in name:
|
||||
@@ -202,14 +225,17 @@ class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
|
||||
pipe.unet.add_adapter(config1, adapter_name="adapter-1")
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.unet),
|
||||
"Lora not correctly set in unet",
|
||||
)
|
||||
|
||||
# sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
|
||||
modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
|
||||
modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
|
||||
assert modules_adapter_0 != modules_adapter_1
|
||||
assert modules_adapter_0 - modules_adapter_1
|
||||
assert modules_adapter_1 - modules_adapter_0
|
||||
self.assertNotEqual(modules_adapter_0, modules_adapter_1)
|
||||
self.assertTrue(modules_adapter_0 - modules_adapter_1)
|
||||
self.assertTrue(modules_adapter_1 - modules_adapter_0)
|
||||
|
||||
# setting both separately works
|
||||
pipe.set_lora_device(["adapter-0"], "cpu")
|
||||
@@ -217,30 +243,32 @@ class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
for name, module in pipe.unet.named_modules():
|
||||
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
|
||||
assert module.weight.device == torch.device("cpu")
|
||||
self.assertTrue(module.weight.device == torch.device("cpu"))
|
||||
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
|
||||
assert module.weight.device == torch.device("cpu")
|
||||
self.assertTrue(module.weight.device == torch.device("cpu"))
|
||||
|
||||
# setting both at once also works
|
||||
pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
|
||||
|
||||
for name, module in pipe.unet.named_modules():
|
||||
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
|
||||
assert module.weight.device != torch.device("cpu")
|
||||
self.assertTrue(module.weight.device != torch.device("cpu"))
|
||||
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
|
||||
assert module.weight.device != torch.device("cpu")
|
||||
self.assertTrue(module.weight.device != torch.device("cpu"))
|
||||
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
class TestSDLoraIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _gc_and_cache_cleanup(self, torch_device):
|
||||
class LoraIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
yield
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -252,7 +280,10 @@ class TestSDLoraIntegration:
|
||||
pipe.load_lora_weights(lora_id)
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder),
|
||||
"Lora not correctly set in text encoder",
|
||||
)
|
||||
|
||||
prompt = "a red sks dog"
|
||||
|
||||
@@ -281,7 +312,10 @@ class TestSDLoraIntegration:
|
||||
pipe.load_lora_weights(lora_id)
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder),
|
||||
"Lora not correctly set in text encoder",
|
||||
)
|
||||
|
||||
prompt = "a red sks dog"
|
||||
|
||||
@@ -553,8 +587,8 @@ class TestSDLoraIntegration:
|
||||
).images
|
||||
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert not np.allclose(initial_images, lora_images)
|
||||
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
|
||||
self.assertFalse(np.allclose(initial_images, lora_images))
|
||||
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
|
||||
|
||||
release_memory(pipe)
|
||||
|
||||
@@ -591,8 +625,8 @@ class TestSDLoraIntegration:
|
||||
).images
|
||||
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert not np.allclose(initial_images, lora_images)
|
||||
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
|
||||
self.assertFalse(np.allclose(initial_images, lora_images))
|
||||
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
|
||||
|
||||
# make sure we can load a LoRA again after unloading and they don't have
|
||||
# any undesired effects.
|
||||
@@ -603,7 +637,7 @@ class TestSDLoraIntegration:
|
||||
).images
|
||||
lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert np.allclose(lora_images, lora_images_again, atol=1e-3)
|
||||
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
|
||||
release_memory(pipe)
|
||||
|
||||
def test_not_empty_state_dict(self):
|
||||
@@ -617,7 +651,7 @@ class TestSDLoraIntegration:
|
||||
lcm_lora = load_file(cached_file)
|
||||
|
||||
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
|
||||
assert lcm_lora != {}
|
||||
self.assertTrue(lcm_lora != {})
|
||||
release_memory(pipe)
|
||||
|
||||
def test_load_unload_load_state_dict(self):
|
||||
@@ -632,11 +666,11 @@ class TestSDLoraIntegration:
|
||||
previous_state_dict = lcm_lora.copy()
|
||||
|
||||
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
|
||||
assert lcm_lora == previous_state_dict
|
||||
self.assertDictEqual(lcm_lora, previous_state_dict)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
|
||||
assert lcm_lora == previous_state_dict
|
||||
self.assertDictEqual(lcm_lora, previous_state_dict)
|
||||
|
||||
release_memory(pipe)
|
||||
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
# limitations under the License.
|
||||
import gc
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_accelerate_available():
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class TestSD3LoRA(PeftLoraLoaderMixinTests):
|
||||
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = StableDiffusion3Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -113,19 +113,19 @@ class TestSD3LoRA(PeftLoraLoaderMixinTests):
|
||||
lora_filename = "lora_peft_format.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
|
||||
@pytest.mark.skip("Not supported in SD3.")
|
||||
@unittest.skip("Not supported in SD3.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in SD3.")
|
||||
@unittest.skip("Not supported in SD3.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in SD3.")
|
||||
@unittest.skip("Not supported in SD3.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in SD3.")
|
||||
@unittest.skip("Not supported in SD3.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@@ -138,15 +138,17 @@ class TestSD3LoRA(PeftLoraLoaderMixinTests):
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_accelerator
|
||||
class TestSD3LoraIntegration:
|
||||
class SD3LoraIntegrationTests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3Img2ImgPipeline
|
||||
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _gc_and_cache_cleanup(self, torch_device):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
yield
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ import gc
|
||||
import importlib
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
@@ -59,7 +59,7 @@ if is_accelerate_available():
|
||||
from accelerate.utils import release_memory
|
||||
|
||||
|
||||
class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
|
||||
class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
has_two_text_encoders = True
|
||||
pipeline_class = StableDiffusionXLPipeline
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
@@ -104,11 +104,21 @@ class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
|
||||
def output_shape(self):
|
||||
return (1, 64, 64, 3)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@is_flaky
|
||||
def test_multiple_wrong_adapter_name_raises_error(self):
|
||||
super().test_multiple_wrong_adapter_name_raises_error()
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
if torch.cuda.is_available():
|
||||
expected_atol = 9e-2
|
||||
expected_rtol = 9e-2
|
||||
@@ -117,10 +127,10 @@ class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
|
||||
expected_rtol = 1e-3
|
||||
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(
|
||||
pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
|
||||
expected_atol=expected_atol, expected_rtol=expected_rtol
|
||||
)
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
if torch.cuda.is_available():
|
||||
expected_atol = 9e-2
|
||||
expected_rtol = 9e-2
|
||||
@@ -129,10 +139,10 @@ class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
|
||||
expected_rtol = 1e-3
|
||||
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(
|
||||
pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
|
||||
expected_atol=expected_atol, expected_rtol=expected_rtol
|
||||
)
|
||||
|
||||
def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
|
||||
def test_lora_scale_kwargs_match_fusion(self):
|
||||
if torch.cuda.is_available():
|
||||
expected_atol = 9e-2
|
||||
expected_rtol = 9e-2
|
||||
@@ -140,21 +150,21 @@ class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
|
||||
expected_atol = 1e-3
|
||||
expected_rtol = 1e-3
|
||||
|
||||
super().test_lora_scale_kwargs_match_fusion(
|
||||
base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol
|
||||
)
|
||||
super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
|
||||
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
class TestLoraSDXLIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _gc_and_cache_cleanup(self, torch_device):
|
||||
class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
yield
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -373,7 +383,7 @@ class TestLoraSDXLIntegration:
|
||||
end_time = time.time()
|
||||
elapsed_time_fusion = end_time - start_time
|
||||
|
||||
assert elapsed_time_fusion < elapsed_time_non_fusion
|
||||
self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion)
|
||||
|
||||
release_memory(pipe)
|
||||
|
||||
@@ -429,14 +439,14 @@ class TestLoraSDXLIntegration:
|
||||
|
||||
for key, value in text_encoder_1_sd.items():
|
||||
key = remap_key(key, fused_te_state_dict)
|
||||
assert torch.allclose(fused_te_state_dict[key], value)
|
||||
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
|
||||
|
||||
for key, value in text_encoder_2_sd.items():
|
||||
key = remap_key(key, fused_te_2_state_dict)
|
||||
assert torch.allclose(fused_te_2_state_dict[key], value)
|
||||
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
|
||||
|
||||
for key, value in unet_state_dict.items():
|
||||
assert torch.allclose(unet_state_dict[key], value)
|
||||
self.assertTrue(torch.allclose(unet_state_dict[key], value))
|
||||
|
||||
pipe.fuse_lora()
|
||||
pipe.unload_lora_weights()
|
||||
@@ -579,7 +589,7 @@ class TestLoraSDXLIntegration:
|
||||
pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet"
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie"
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -39,7 +39,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class TestWanLoRA(PeftLoraLoaderMixinTests):
|
||||
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -104,40 +104,40 @@ class TestWanLoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@pytest.mark.skip("Not supported in Wan.")
|
||||
@unittest.skip("Not supported in Wan.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Wan.")
|
||||
@unittest.skip("Not supported in Wan.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Wan.")
|
||||
@unittest.skip("Not supported in Wan.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -31,6 +32,7 @@ from ..testing_utils import (
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@@ -45,7 +47,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
@is_flaky(max_attempts=10, description="very flaky class")
|
||||
class TestWanVACELoRA(PeftLoraLoaderMixinTests):
|
||||
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanVACEPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -119,51 +121,56 @@ class TestWanVACELoRA(PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@pytest.mark.skip("Not supported in Wan VACE.")
|
||||
@unittest.skip("Not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Wan VACE.")
|
||||
@unittest.skip("Not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Not supported in Wan VACE.")
|
||||
@unittest.skip("Not supported in Wan VACE.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
def test_layerwise_casting_inference_denoiser(self):
|
||||
super().test_layerwise_casting_inference_denoiser()
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname, pipe):
|
||||
def test_lora_exclude_modules_wanvace(self):
|
||||
exclude_module_name = "vace_blocks.0.proj_out"
|
||||
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
assert base_pipe_output.shape == self.output_shape
|
||||
output_no_lora = self.get_base_pipe_output()
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
# only supported for `denoiser` now
|
||||
denoiser_lora_config.target_modules = ["proj_out"]
|
||||
@@ -173,30 +180,36 @@ class TestWanVACELoRA(PeftLoraLoaderMixinTests):
|
||||
)
|
||||
# The state dict shouldn't contain the modules to be excluded from LoRA.
|
||||
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
|
||||
assert not any(exclude_module_name in k for k in state_dict_from_model)
|
||||
assert any("proj_out" in k for k in state_dict_from_model)
|
||||
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
|
||||
self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
|
||||
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
|
||||
pipe.unload_lora_weights()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
# Check in the loaded state dict.
|
||||
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
assert not any(exclude_module_name in k for k in loaded_state_dict)
|
||||
assert any("proj_out" in k for k in loaded_state_dict)
|
||||
# Check in the loaded state dict.
|
||||
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
|
||||
self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
|
||||
|
||||
# Check in the state dict obtained after loading LoRA.
|
||||
pipe.load_lora_weights(tmpdirname)
|
||||
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
|
||||
assert not any(exclude_module_name in k for k in state_dict_from_model)
|
||||
assert any("proj_out" in k for k in state_dict_from_model)
|
||||
# Check in the state dict obtained after loading LoRA.
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
|
||||
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
|
||||
self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
|
||||
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
|
||||
"LoRA should change outputs."
|
||||
)
|
||||
assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
|
||||
"Lora outputs should match."
|
||||
)
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(
|
||||
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should change outputs.",
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Lora outputs should match.",
|
||||
)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_and_scale()
|
||||
|
||||
+1013
-578
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user