Correct sigmas cpu settings (#6708)
This commit is contained in:
committed by
GitHub
parent
87bfbc320d
commit
3e9716f22b
@@ -98,7 +98,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.custom_timesteps = False
|
self.custom_timesteps = False
|
||||||
self.is_scale_input_called = False
|
self.is_scale_input_called = False
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
@@ -231,7 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Modified _convert_to_karras implementation that takes in ramp as argument
|
# Modified _convert_to_karras implementation that takes in ramp as argument
|
||||||
def _convert_to_karras(self, ramp):
|
def _convert_to_karras(self, ramp):
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.model_outputs = [None] * solver_order
|
self.model_outputs = [None] * solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def step_index(self):
|
def step_index(self):
|
||||||
@@ -255,7 +255,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# add an index counter for schedulers that allow duplicated timesteps
|
# add an index counter for schedulers that allow duplicated timesteps
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.model_outputs = [None] * solver_order
|
self.model_outputs = [None] * solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def step_index(self):
|
def step_index(self):
|
||||||
@@ -311,7 +311,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# add an index counter for schedulers that allow duplicated timesteps
|
# add an index counter for schedulers that allow duplicated timesteps
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.model_outputs = [None] * solver_order
|
self.model_outputs = [None] * solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
self.use_karras_sigmas = use_karras_sigmas
|
self.use_karras_sigmas = use_karras_sigmas
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -294,7 +294,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# add an index counter for schedulers that allow duplicated timesteps
|
# add an index counter for schedulers that allow duplicated timesteps
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.noise_sampler = None
|
self.noise_sampler = None
|
||||||
self.noise_sampler_seed = noise_sampler_seed
|
self.noise_sampler_seed = noise_sampler_seed
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||||
@@ -348,7 +348,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.mid_point_sigma = None
|
self.mid_point_sigma = None
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
self.noise_sampler = None
|
self.noise_sampler = None
|
||||||
|
|
||||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.sample = None
|
self.sample = None
|
||||||
self.order_list = self.get_order_list(num_train_timesteps)
|
self.order_list = self.get_order_list(num_train_timesteps)
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||||
"""
|
"""
|
||||||
@@ -315,7 +315,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# add an index counter for schedulers that allow duplicated timesteps
|
# add an index counter for schedulers that allow duplicated timesteps
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.is_scale_input_called = False
|
self.is_scale_input_called = False
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def init_noise_sigma(self):
|
def init_noise_sigma(self):
|
||||||
@@ -300,7 +300,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep):
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.use_karras_sigmas = use_karras_sigmas
|
self.use_karras_sigmas = use_karras_sigmas
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def init_noise_sigma(self):
|
def init_noise_sigma(self):
|
||||||
@@ -342,7 +342,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
def _sigma_to_t(self, sigma, log_sigmas):
|
def _sigma_to_t(self, sigma, log_sigmas):
|
||||||
# get log sigma
|
# get log sigma
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.use_karras_sigmas = use_karras_sigmas
|
self.use_karras_sigmas = use_karras_sigmas
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
@@ -270,7 +270,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.dt = None
|
self.dt = None
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# (YiYi Notes: keep this for now since we are keeping the add_noise function, which uses index_for_timestep)
|
# (YiYi Notes: keep this for now since we are keeping the add_noise function, which uses index_for_timestep)
|
||||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# set all values
|
# set all values
|
||||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||||
@@ -300,7 +300,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self._index_counter = defaultdict(int)
|
self._index_counter = defaultdict(int)
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||||
def _sigma_to_t(self, sigma, log_sigmas):
|
def _sigma_to_t(self, sigma, log_sigmas):
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||||
@@ -285,7 +285,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self._index_counter = defaultdict(int)
|
self._index_counter = defaultdict(int)
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state_in_first_order(self):
|
def state_in_first_order(self):
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.is_scale_input_called = False
|
self.is_scale_input_called = False
|
||||||
|
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def init_noise_sigma(self):
|
def init_noise_sigma(self):
|
||||||
@@ -280,7 +280,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
self.derivatives = []
|
self.derivatives = []
|
||||||
|
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
self.last_sample = None
|
self.last_sample = None
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def step_index(self):
|
def step_index(self):
|
||||||
@@ -283,7 +283,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# add an index counter for schedulers that allow duplicated timesteps
|
# add an index counter for schedulers that allow duplicated timesteps
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.solver_p = solver_p
|
self.solver_p = solver_p
|
||||||
self.last_sample = None
|
self.last_sample = None
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def step_index(self):
|
def step_index(self):
|
||||||
@@ -269,7 +269,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# add an index counter for schedulers that allow duplicated timesteps
|
# add an index counter for schedulers that allow duplicated timesteps
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
|||||||
Reference in New Issue
Block a user