[docs] Add missing copied from statements in TCD Scheduler (#7360)
* add missing copied from statements in tcd scheduler
* update docstring

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -307,6 +307,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
+
         Returns:
             `torch.FloatTensor`:
                 A scaled input sample.
@@ -364,7 +365,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         device: Union[str, torch.device] = None,
         original_inference_steps: Optional[int] = None,
         timesteps: Optional[List[int]] = None,
-        strength: int = 1.0,
+        strength: float = 1.0,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -384,6 +385,8 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                 timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                 schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
+            strength (`float`, *optional*, defaults to 1.0):
+                Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.
         """
         # 0. Check inputs
         if num_inference_steps is None and timesteps is None:
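Note: `strength` is documented above as a plain float; a minimal usage sketch, assuming only the public `TCDScheduler` API touched by this diff (the concrete values are illustrative, not part of the change):

    from diffusers import TCDScheduler

    scheduler = TCDScheduler()
    # strength < 1.0 exposes only part of the training/distillation schedule,
    # which is how img2img/inpaint pipelines shorten the denoising run.
    scheduler.set_timesteps(num_inference_steps=8, strength=0.5, device="cpu")
    print(scheduler.timesteps)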
@@ -624,6 +627,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
 
         return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -631,7 +635,10 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+        # for the subsequent add_noise calls
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
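Note: the rewrite above keeps the device move on `self`, so repeated `add_noise` calls skip the CPU-to-GPU copy and only pay for the per-call dtype cast. A standalone sketch of the same caching pattern; the class name, schedule values, and method name are hypothetical stand-ins, not diffusers code:

    import torch

    class NoiseScheduleCache:
        # Hypothetical stand-in illustrating the caching pattern used in add_noise.
        def __init__(self, num_train_timesteps: int = 1000):
            # Lives on CPU until first use, then stays on the sample's device.
            self.alphas_cumprod = torch.linspace(0.9999, 0.0001, num_train_timesteps)

        def gather(self, sample: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
            # Persist the (potentially expensive) device move across calls ...
            self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
            # ... and do only the cheap dtype cast on every call.
            alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
            return alphas_cumprod[timesteps.to(sample.device)]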
@@ -647,11 +654,13 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)
 
         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
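Note: for reference, the DDPM body that the new `# Copied from` points at computes the v-prediction target. A condensed restatement of the same math (the function name is ours, and `alpha_bar_t` is assumed already gathered and broadcast to the sample shape):

    import torch

    def velocity_target(sample: torch.Tensor, noise: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
        # Same math as DDPMScheduler.get_velocity:
        #   v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0
        return alpha_bar_t.sqrt() * noise - (1.0 - alpha_bar_t).sqrt() * sample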
@@ -670,6 +679,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
     def __len__(self):
         return self.config.num_train_timesteps
 
+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
         if self.custom_timesteps:
            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
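Note: the value of these `# Copied from` comments is that diffusers machine-checks them (via the repo's `make fix-copies` tooling), so the TCD bodies cannot silently drift from their DDPM sources. A hypothetical mini-checker sketching the idea; this is not the real utility:

    import inspect

    def bodies_match(src_fn, dst_fn) -> bool:
        # Hypothetical: the copied body (everything after the `def` line)
        # must match its source exactly.
        def body(fn):
            return inspect.getsource(fn).split("\n", 1)[1].strip()

        return body(src_fn) == body(dst_fn)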