From e4546fd5bb5c9e4b74b5e843823d84ad0d20ccbd Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 19 Mar 2024 16:15:36 +0530
Subject: [PATCH] [docs] Add missing copied from statements in TCD Scheduler
 (#7360)

* add missing copied from statements in tcd scheduler

* update docstring

---------

Co-authored-by: Sayak Paul
---
 src/diffusers/schedulers/scheduling_tcd.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 40abf2d6b6..ee3cde5d21 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -307,6 +307,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
+
         Returns:
             `torch.FloatTensor`:
                 A scaled input sample.
@@ -364,7 +365,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         device: Union[str, torch.device] = None,
         original_inference_steps: Optional[int] = None,
         timesteps: Optional[List[int]] = None,
-        strength: int = 1.0,
+        strength: float = 1.0,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -384,6 +385,8 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                 timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                 schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
+            strength (`float`, *optional*, defaults to 1.0):
+                Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.
         """
         # 0. Check inputs
         if num_inference_steps is None and timesteps is None:
@@ -624,6 +627,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
 
         return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -631,7 +635,10 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -647,11 +654,13 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -670,6 +679,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
     def __len__(self):
         return self.config.num_train_timesteps
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
         if self.custom_timesteps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
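
A note on the add_noise hunk: it is a small behavioral change, not only a docs fix. Reassigning `self.alphas_cumprod` onto the samples' device means the CPU-to-GPU transfer happens once, on the first call, and every later call pays only the per-call dtype cast. Below is a minimal sketch of the effect, assuming a diffusers version recent enough to export `TCDScheduler` from the top-level package (>= 0.26):

    import torch
    from diffusers import TCDScheduler

    scheduler = TCDScheduler()  # default config: 1000 training timesteps

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sample = torch.randn(1, 4, 64, 64, device=device)
    noise = torch.randn_like(sample)
    timesteps = torch.tensor([999], device=device)

    # First call moves scheduler.alphas_cumprod onto `device` and keeps it there.
    noisy = scheduler.add_noise(sample, noise, timesteps)
    assert scheduler.alphas_cumprod.device == sample.device

    # Later calls reuse the cached tensor; only the dtype cast remains per call.
    noisy = scheduler.add_noise(sample, noise, timesteps)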
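
The `strength` docstring added above documents an already-existing parameter rather than introducing one. A short usage sketch, assuming TCD's img2img convention that `strength` truncates the origin timestep schedule so the selected timesteps come from its lower-noise portion:

    import torch
    from diffusers import TCDScheduler

    scheduler = TCDScheduler()
    scheduler.set_timesteps(num_inference_steps=4, strength=0.5)
    # All four selected timesteps now fall in (roughly) the first half of the
    # 1000-step training schedule, mirroring how img2img skips early steps.
    print(scheduler.timesteps)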