Fix dreambooth loss type with prior_preservation and fp16 (#826)

Fix dreambooth loss type with prior preservation
This commit is contained in:
Anton Lozhkov
2022-10-13 15:41:19 +02:00
committed by GitHub
parent 0a09af2f0a
commit e001fededf
+1 -1
View File
@@ -544,7 +544,7 @@ def main():
noise, noise_prior = torch.chunk(noise, 2, dim=0) noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Compute instance loss # Compute instance loss
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss # Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")