[docs] Add missing copied from statements in TCD Scheduler (#7360)
* add missing copied from statements in tcd scheduler * update docstring --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -307,6 +307,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
The input sample.
|
The input sample.
|
||||||
timestep (`int`, *optional*):
|
timestep (`int`, *optional*):
|
||||||
The current timestep in the diffusion chain.
|
The current timestep in the diffusion chain.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`torch.FloatTensor`:
|
`torch.FloatTensor`:
|
||||||
A scaled input sample.
|
A scaled input sample.
|
||||||
@@ -364,7 +365,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
device: Union[str, torch.device] = None,
|
device: Union[str, torch.device] = None,
|
||||||
original_inference_steps: Optional[int] = None,
|
original_inference_steps: Optional[int] = None,
|
||||||
timesteps: Optional[List[int]] = None,
|
timesteps: Optional[List[int]] = None,
|
||||||
strength: int = 1.0,
|
strength: float = 1.0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||||
@@ -384,6 +385,8 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||||
timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
|
timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
|
||||||
schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
|
schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
|
||||||
|
strength (`float`, *optional*, defaults to 1.0):
|
||||||
|
Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.
|
||||||
"""
|
"""
|
||||||
# 0. Check inputs
|
# 0. Check inputs
|
||||||
if num_inference_steps is None and timesteps is None:
|
if num_inference_steps is None and timesteps is None:
|
||||||
@@ -624,6 +627,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
|
return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||||
def add_noise(
|
def add_noise(
|
||||||
self,
|
self,
|
||||||
original_samples: torch.FloatTensor,
|
original_samples: torch.FloatTensor,
|
||||||
@@ -631,7 +635,10 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
|
||||||
|
# for the subsequent add_noise calls
|
||||||
|
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
|
||||||
|
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
|
||||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||||
@@ -647,11 +654,13 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||||
return noisy_samples
|
return noisy_samples
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
||||||
def get_velocity(
|
def get_velocity(
|
||||||
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
||||||
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
|
||||||
|
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
|
||||||
timesteps = timesteps.to(sample.device)
|
timesteps = timesteps.to(sample.device)
|
||||||
|
|
||||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||||
@@ -670,6 +679,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.config.num_train_timesteps
|
return self.config.num_train_timesteps
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
|
||||||
def previous_timestep(self, timestep):
|
def previous_timestep(self, timestep):
|
||||||
if self.custom_timesteps:
|
if self.custom_timesteps:
|
||||||
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
|
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
|
||||||
|
|||||||
Reference in New Issue
Block a user