add sigmoid betas (#777)
* add sigmoid betas * convert to torch * add comment on source
This commit is contained in:
@@ -133,6 +133,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
elif beta_schedule == "squaredcos_cap_v2":
|
elif beta_schedule == "squaredcos_cap_v2":
|
||||||
# Glide cosine schedule
|
# Glide cosine schedule
|
||||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||||
|
elif beta_schedule == "sigmoid":
|
||||||
|
# GeoDiff sigmoid schedule
|
||||||
|
betas = torch.linspace(-6, 6, num_train_timesteps)
|
||||||
|
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user