[LoRA] fix cross_attention_kwargs problems and tighten tests (#7388)
* debugging * let's see the numbers * let's see the numbers * let's see the numbers * restrict tolerance. * increase inference steps. * shallow copy of cross_attentionkwargs * remove print
This commit is contained in:
@@ -1178,6 +1178,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
|||||||
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
|
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
|
||||||
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
|
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
|
||||||
if cross_attention_kwargs is not None:
|
if cross_attention_kwargs is not None:
|
||||||
|
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||||
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
||||||
else:
|
else:
|
||||||
lora_scale = 1.0
|
lora_scale = 1.0
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class PeftLoraLoaderMixinTests:
|
|||||||
|
|
||||||
pipeline_inputs = {
|
pipeline_inputs = {
|
||||||
"prompt": "A painting of a squirrel eating a burger",
|
"prompt": "A painting of a squirrel eating a burger",
|
||||||
"num_inference_steps": 2,
|
"num_inference_steps": 5,
|
||||||
"guidance_scale": 6.0,
|
"guidance_scale": 6.0,
|
||||||
"output_type": "np",
|
"output_type": "np",
|
||||||
}
|
}
|
||||||
@@ -589,7 +589,7 @@ class PeftLoraLoaderMixinTests:
|
|||||||
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
|
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
|
||||||
).images
|
).images
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
|
not np.allclose(output_lora, output_lora_scale, atol=1e-4, rtol=1e-4),
|
||||||
"Lora + scale should change the output",
|
"Lora + scale should change the output",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1300,6 +1300,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
|||||||
pipe.load_lora_weights(lora_id)
|
pipe.load_lora_weights(lora_id)
|
||||||
pipe = pipe.to("cuda")
|
pipe = pipe.to("cuda")
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
self.check_if_lora_correctly_set(pipe.unet),
|
||||||
|
"Lora not correctly set in UNet",
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
self.check_if_lora_correctly_set(pipe.text_encoder),
|
self.check_if_lora_correctly_set(pipe.text_encoder),
|
||||||
"Lora not correctly set in text encoder 2",
|
"Lora not correctly set in text encoder 2",
|
||||||
|
|||||||
Reference in New Issue
Block a user