FIX Test to ignore warning for enable_lora_hotswap (#12421)
I noticed that the test should be for the option check_compiled="ignore" but it was using check_compiled="warn". This has been fixed, now the correct argument is passed. However, the fact that the test passed means that it was incorrect to begin with. The way that logs are collected does not collect the logger.warning call here (not sure why). To amend this, I'm now using assertNoLogs. With this change, the test correctly fails when the wrong argument is passed.
This commit is contained in:
@@ -25,7 +25,6 @@ import traceback
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -2373,14 +2372,15 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
|
||||
# check possibility to ignore the error/warning
|
||||
from diffusers.loaders.peft import logger
|
||||
|
||||
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always") # Capture all warnings
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
|
||||
# note: assertNoLogs requires Python 3.10+
|
||||
with self.assertNoLogs(logger, level="WARNING"):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
|
||||
Reference in New Issue
Block a user