Fix dreambooth loss type with prior_preservation and fp16 (#826)
Fix dreambooth loss type with prior preservation
This commit is contained in:
@@ -544,7 +544,7 @@ def main():
|
|||||||
noise, noise_prior = torch.chunk(noise, 2, dim=0)
|
noise, noise_prior = torch.chunk(noise, 2, dim=0)
|
||||||
|
|
||||||
# Compute instance loss
|
# Compute instance loss
|
||||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
|
||||||
|
|
||||||
# Compute prior loss
|
# Compute prior loss
|
||||||
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
|
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
|
||||||
|
|||||||
Reference in New Issue
Block a user