fix: DDPMScheduler.set_timesteps() (#1912)
This commit is contained in:
@@ -201,6 +201,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
num_inference_steps (`int`):
|
num_inference_steps (`int`):
|
||||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if num_inference_steps > self.config.num_train_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||||
|
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||||
|
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||||
|
)
|
||||||
|
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||||
# creates integer timesteps by multiplying by ratio
|
# creates integer timesteps by multiplying by ratio
|
||||||
|
|||||||
@@ -184,11 +184,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
num_inference_steps (`int`):
|
num_inference_steps (`int`):
|
||||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
"""
|
"""
|
||||||
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
|
|
||||||
|
if num_inference_steps > self.config.num_train_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||||
|
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||||
|
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||||
|
)
|
||||||
|
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
timesteps = np.arange(
|
|
||||||
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
|
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||||
)[::-1].copy()
|
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
||||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||||
|
|
||||||
def _get_variance(self, t, predicted_variance=None, variance_type=None):
|
def _get_variance(self, t, predicted_variance=None, variance_type=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user