[Refactor] Remove set_seed (#289)

* [Refactor] Remove set_seed and class attributes

* apply anton's suggestiosn

* fix

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

* update

* make style

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <anton@huggingface.co>

* make fix-copies

* make style

* make style and new copies

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
Patrick von Platen
2022-08-31 19:29:38 +02:00
committed by GitHub
parent 384fcac6df
commit f3937bc8f3
7 changed files with 25 additions and 16 deletions
@@ -56,7 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
model_output = self.unet(image, t)["sample"] model_output = self.unet(image, t)["sample"]
# 2. compute previous image: x_t -> t_t-1 # 2. compute previous image: x_t -> t_t-1
image = self.scheduler.step(model_output, t, image)["prev_sample"] image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -30,7 +30,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
model = self.unet model = self.unet
sample = torch.randn(*shape) * self.scheduler.config.sigma_max sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
sample = sample.to(self.device) sample = sample.to(self.device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
@@ -42,11 +42,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
# correction step # correction step
for _ in range(self.scheduler.correct_steps): for _ in range(self.scheduler.correct_steps):
model_output = self.unet(sample, sigma_t)["sample"] model_output = self.unet(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"] sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
# prediction step # prediction step
model_output = model(sample, sigma_t)["sample"] model_output = model(sample, sigma_t)["sample"]
output = self.scheduler.step_pred(model_output, t, sample) output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
@@ -19,6 +19,7 @@ class KarrasVePipeline(DiffusionPipeline):
differential equations." https://arxiv.org/abs/2011.13456 differential equations." https://arxiv.org/abs/2011.13456
""" """
# add type hints for linting
unet: UNet2DModel unet: UNet2DModel
scheduler: KarrasVeScheduler scheduler: KarrasVeScheduler
+17 -11
View File
@@ -14,8 +14,8 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit import warnings
from typing import Union from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
@@ -98,6 +98,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_seed(self, seed): def set_seed(self, seed):
warnings.warn(
"The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
" generator instead.",
DeprecationWarning,
)
tensor_format = getattr(self, "tensor_format", "pt") tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np": if tensor_format == "np":
np.random.seed(seed) np.random.seed(seed)
@@ -111,14 +116,14 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
seed=None, generator: Optional[torch.Generator] = None,
**kwargs,
): ):
""" """
Predict the sample at the previous timestep by reversing the SDE. Predict the sample at the previous timestep by reversing the SDE.
""" """
if seed is not None: if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(seed) self.set_seed(kwargs["seed"])
# TODO(Patrick) non-PyTorch
if self.timesteps is None: if self.timesteps is None:
raise ValueError( raise ValueError(
@@ -140,7 +145,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
drift = drift - diffusion[:, None, None, None] ** 2 * model_output drift = drift - diffusion[:, None, None, None] ** 2 * model_output
# equation 6: sample noise for the diffusion term of # equation 6: sample noise for the diffusion term of
noise = self.randn_like(sample) noise = self.randn_like(sample, generator=generator)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise? # TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
@@ -151,14 +156,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
seed=None, generator: Optional[torch.Generator] = None,
**kwargs,
): ):
""" """
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep. after making the prediction for the previous timestep.
""" """
if seed is not None: if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(seed) self.set_seed(kwargs["seed"])
if self.timesteps is None: if self.timesteps is None:
raise ValueError( raise ValueError(
@@ -167,7 +173,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction # sample noise for correction
noise = self.randn_like(sample) noise = self.randn_like(sample, generator=generator)
# compute step size from the model_output, the noise, and the snr # compute step size from the model_output, the noise, and the snr
grad_norm = self.norm(model_output) grad_norm = self.norm(model_output)
@@ -1,5 +1,6 @@
# This file is autogenerated by the command `make fix-copies`, do not edit. # This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa # flake8: noqa
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
@@ -1,5 +1,6 @@
# This file is autogenerated by the command `make fix-copies`, do not edit. # This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa # flake8: noqa
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
+1 -1
View File
@@ -107,7 +107,7 @@ def create_dummy_files():
for backend, objects in backend_specific_objects.items(): for backend, objects in backend_specific_objects.items():
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]" backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
dummy_file += "# flake8: noqa\n" dummy_file += "# flake8: noqa\n\n"
dummy_file += "from ..utils import DummyObject, requires_backends\n\n" dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects]) dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
dummy_files[backend] = dummy_file dummy_files[backend] = dummy_file