Improve docstrings and type hints in scheduling_dpmsolver_multistep.py (#12710)
* Improve docstrings and type hints in multiple diffusion schedulers * docs: update Imagen Video paper link to Hugging Face Papers.
This commit is contained in:
parent
c8656ed73c
commit
a88a7b4f03
@ -429,7 +429,22 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@ -452,6 +467,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
|
||||
@ -401,6 +401,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@ -808,7 +819,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise NotImplementedError("only support log-rho multistep deis now")
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@ -831,6 +857,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
@ -927,6 +957,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples without noise.
|
||||
noise (`torch.Tensor`):
|
||||
The noise to add to the samples.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise to the samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
|
||||
@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
solver_order (`int`, defaults to 2):
|
||||
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
|
||||
prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
|
||||
Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
|
||||
directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
|
||||
Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||
as Stable Diffusion.
|
||||
@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
||||
`algorithm_type="dpmsolver++"`.
|
||||
algorithm_type (`str`, defaults to `dpmsolver++`):
|
||||
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
||||
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
||||
paper, and the `dpmsolver++` type implements the algorithms in the
|
||||
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
||||
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
||||
solver_type (`str`, defaults to `midpoint`):
|
||||
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
||||
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
||||
algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
|
||||
Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
|
||||
[DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
|
||||
algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
|
||||
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
||||
solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
|
||||
Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
|
||||
for a small number of steps. It is recommended to use `midpoint` solvers.
|
||||
lower_order_final (`bool`, defaults to `True`):
|
||||
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
||||
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
||||
@ -179,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
||||
flow_shift (`float`, *optional*, defaults to 1.0):
|
||||
The shift value for the timestep schedule for flow matching.
|
||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
||||
final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
variance_type (`str`, *optional*):
|
||||
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
||||
contains the predicted Gaussian variance.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
variance_type (`"learned"` or `"learned_range"`, *optional*):
|
||||
Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
|
||||
output contains the predicted Gaussian variance.
|
||||
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
@ -197,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
use_dynamic_shifting (`bool`, defaults to `False`):
|
||||
Whether to use dynamic shifting for the timestep schedule.
|
||||
time_shift_type (`"exponential"`, defaults to `"exponential"`):
|
||||
The type of time shift to apply when using dynamic shifting.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@ -208,15 +210,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
|
||||
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
@ -225,14 +227,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_lu_lambdas: Optional[bool] = False,
|
||||
use_flow_sigmas: Optional[bool] = False,
|
||||
flow_shift: Optional[float] = 1.0,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
variance_type: Optional[Literal["learned", "learned_range"]] = None,
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: str = "exponential",
|
||||
time_shift_type: Literal["exponential"] = "exponential",
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@ -331,19 +333,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
mu: Optional[float] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
num_inference_steps (`int`, *optional*):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
mu (`float`, *optional*):
|
||||
The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
|
||||
`time_shift_type="exponential"`.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
|
||||
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
|
||||
@ -503,7 +508,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert sigma values to corresponding timestep values through interpolation.
|
||||
|
||||
@ -539,7 +544,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@ -588,8 +604,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Lu et al. (2022)."""
|
||||
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
|
||||
Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
|
||||
|
||||
Args:
|
||||
in_lambdas (`torch.Tensor`):
|
||||
The input lambda values to be converted.
|
||||
num_inference_steps (`int`):
|
||||
The number of inference steps to generate the noise schedule for.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The converted lambda values following the Lu noise schedule.
|
||||
"""
|
||||
|
||||
lambda_min: float = in_lambdas[-1].item()
|
||||
lambda_max: float = in_lambdas[0].item()
|
||||
@ -1069,7 +1098,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@ -1088,9 +1132,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return step_index
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
@ -1105,7 +1153,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
@ -1115,22 +1163,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`int`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
variance_noise (`torch.Tensor`):
|
||||
variance_noise (`torch.Tensor`, *optional*):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`LEdits++`].
|
||||
return_dict (`bool`):
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
@ -1210,6 +1258,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples without noise.
|
||||
noise (`torch.Tensor`):
|
||||
The noise to add to the samples.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise to the samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
|
||||
@ -413,6 +413,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
|
||||
@ -491,6 +491,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@ -1079,7 +1090,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@ -1102,6 +1128,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
@ -1204,6 +1234,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples without noise.
|
||||
noise (`torch.Tensor`):
|
||||
The noise to add to the samples.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise to the samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
|
||||
@ -578,7 +578,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@ -601,6 +616,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
|
||||
@ -423,6 +423,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@ -1103,7 +1114,22 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@ -1126,6 +1152,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
|
||||
@ -513,6 +513,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
Args:
|
||||
sigma (`torch.Tensor`):
|
||||
The sigma value(s) to convert.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing (alpha_t, sigma_t) values.
|
||||
"""
|
||||
if self.config.use_flow_sigmas:
|
||||
alpha_t = 1 - sigma
|
||||
sigma_t = sigma
|
||||
@ -984,7 +995,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The timestep for which to find the index.
|
||||
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The index of the timestep in the schedule.
|
||||
"""
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
@ -1007,6 +1033,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _init_step_index(self, timestep):
|
||||
"""
|
||||
Initialize the step_index counter for the scheduler.
|
||||
|
||||
Args:
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
The current timestep for which to initialize the step index.
|
||||
"""
|
||||
|
||||
if self.begin_index is None:
|
||||
@ -1119,6 +1149,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples without noise.
|
||||
noise (`torch.Tensor`):
|
||||
The noise to add to the samples.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise to the samples.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user