Fix EMAModel test_from_pretrained (#10325)

This commit is contained in:
hlky 2024-12-21 14:10:44 +00:00 committed by GitHub
parent a756694bf0
commit bf9a641f1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase):
# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase):
# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):