From cedafb8600be106081479d56d6c2c7331226045a Mon Sep 17 00:00:00 2001 From: Dudu Moshe <53430514+dudulightricks@users.noreply.github.com> Date: Tue, 31 Jan 2023 10:13:26 +0200 Subject: [PATCH] [Bug]: fix DDPM scheduler arbitrary infer steps count. (#2076) scheduling_ddpm: fix evaluate with lower timesteps count than train. Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_ddpm.py | 21 ++++++++----- .../audio_diffusion/test_audio_diffusion.py | 2 +- tests/pipelines/ddpm/test_ddpm.py | 30 +++++++++++++++++-- tests/test_scheduler.py | 8 ++--- 4 files changed, 47 insertions(+), 14 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 2f802ba126..63b31033c9 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -189,13 +189,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t, predicted_variance=None, variance_type=None): + num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + prev_t = t - self.config.num_train_timesteps // num_inference_steps alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t if variance_type is None: variance_type = self.config.variance_type @@ -208,10 +211,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): variance = torch.log(torch.clamp(variance, min=1e-20)) variance = torch.exp(0.5 * variance) elif variance_type == "fixed_large": - variance = self.betas[t] + variance = current_beta_t elif variance_type == "fixed_large_log": # Glide max_log - variance = torch.log(self.betas[t]) + variance = torch.log(current_beta_t) elif variance_type == "learned": return predicted_variance elif variance_type == "learned_range": @@ -249,6 +252,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ t = timestep + num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) @@ -257,9 +262,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf @@ -281,8 +288,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t - current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous sample µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf diff --git a/tests/pipelines/audio_diffusion/test_audio_diffusion.py b/tests/pipelines/audio_diffusion/test_audio_diffusion.py index dc706d868a..b68e940bdc 100644 --- a/tests/pipelines/audio_diffusion/test_audio_diffusion.py +++ b/tests/pipelines/audio_diffusion/test_audio_diffusion.py @@ -118,7 +118,7 @@ class PipelineFastTests(unittest.TestCase): assert image.height == self.dummy_unet.sample_size[0] and image.width == self.dummy_unet.sample_size[1] image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10] image_from_tuple_slice = np.frombuffer(image_from_tuple.tobytes(), dtype="uint8")[:10] - expected_slice = np.array([255, 255, 255, 0, 181, 0, 124, 0, 15, 255]) + expected_slice = np.array([69, 255, 255, 255, 0, 0, 77, 181, 12, 127]) assert np.abs(image_slice.flatten() - expected_slice).max() == 0 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() == 0 diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index 0bc7d0475f..a16b3782a4 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -40,7 +40,7 @@ class DDPMPipelineFastTests(unittest.TestCase): ) return model - def test_inference(self): + def test_fast_inference(self): device = "cpu" unet = self.dummy_uncond_unet scheduler = DDPMScheduler() @@ -60,7 +60,33 @@ class DDPMPipelineFastTests(unittest.TestCase): assert image.shape == (1, 32, 32, 3) expected_slice = np.array( - [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02] + [9.956e-01, 5.785e-01, 4.675e-01, 9.930e-01, 0.0, 1.000, 1.199e-03, 2.648e-04, 5.101e-04] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_full_inference(self): + device = "cpu" + unet = self.dummy_uncond_unet + scheduler = DDPMScheduler() + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(device) + ddpm.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + image = ddpm(generator=generator, output_type="numpy").images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = ddpm(generator=generator, output_type="numpy", return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array( + [1.0, 3.495e-02, 2.939e-01, 9.821e-01, 9.448e-01, 6.261e-03, 7.998e-01, 8.9e-01, 1.122e-02] ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index d49d599c57..f38b6b6b34 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -686,8 +686,8 @@ class DDPMSchedulerTest(SchedulerCommonTest): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 258.9070) < 1e-2 - assert abs(result_mean.item() - 0.3374) < 1e-3 + assert abs(result_sum.item() - 258.9606) < 1e-2 + assert abs(result_mean.item() - 0.3372) < 1e-3 def test_full_loop_with_v_prediction(self): scheduler_class = self.scheduler_classes[0] @@ -717,8 +717,8 @@ class DDPMSchedulerTest(SchedulerCommonTest): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 201.9864) < 1e-2 - assert abs(result_mean.item() - 0.2630) < 1e-3 + assert abs(result_sum.item() - 202.0296) < 1e-2 + assert abs(result_mean.item() - 0.2631) < 1e-3 class DDIMSchedulerTest(SchedulerCommonTest):