Improve docstrings and type hints in scheduling_dpmsolver_multistep.py (#12710)
* Improve docstrings and type hints in multiple diffusion schedulers * docs: update Imagen Video paper link to Hugging Face Papers.
This commit is contained in:
parent
c8656ed73c
commit
a88a7b4f03
@ -429,7 +429,22 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(
|
||||||
|
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Find the index for a given timestep in the schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The timestep for which to find the index.
|
||||||
|
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||||
|
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`int`:
|
||||||
|
The index of the timestep in the schedule.
|
||||||
|
"""
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
schedule_timesteps = self.timesteps
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
@ -452,6 +467,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep):
|
||||||
"""
|
"""
|
||||||
Initialize the step_index counter for the scheduler.
|
Initialize the step_index counter for the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The current timestep for which to initialize the step index.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
|
|||||||
@ -401,6 +401,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||||
|
"""
|
||||||
|
Convert sigma values to alpha_t and sigma_t values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sigma (`torch.Tensor`):
|
||||||
|
The sigma value(s) to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||||
|
A tuple containing (alpha_t, sigma_t) values.
|
||||||
|
"""
|
||||||
if self.config.use_flow_sigmas:
|
if self.config.use_flow_sigmas:
|
||||||
alpha_t = 1 - sigma
|
alpha_t = 1 - sigma
|
||||||
sigma_t = sigma
|
sigma_t = sigma
|
||||||
@ -808,7 +819,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
raise NotImplementedError("only support log-rho multistep deis now")
|
raise NotImplementedError("only support log-rho multistep deis now")
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(
|
||||||
|
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Find the index for a given timestep in the schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The timestep for which to find the index.
|
||||||
|
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||||
|
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`int`:
|
||||||
|
The index of the timestep in the schedule.
|
||||||
|
"""
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
schedule_timesteps = self.timesteps
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
@ -831,6 +857,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep):
|
||||||
"""
|
"""
|
||||||
Initialize the step_index counter for the scheduler.
|
Initialize the step_index counter for the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The current timestep for which to initialize the step index.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
@ -927,6 +957,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_samples (`torch.Tensor`):
|
||||||
|
The original samples without noise.
|
||||||
|
noise (`torch.Tensor`):
|
||||||
|
The noise to add to the samples.
|
||||||
|
timesteps (`torch.IntTensor`):
|
||||||
|
The timesteps at which to add noise to the samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
The noisy samples.
|
||||||
|
"""
|
||||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||||
|
|||||||
@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
The starting `beta` value of inference.
|
The starting `beta` value of inference.
|
||||||
beta_end (`float`, defaults to 0.02):
|
beta_end (`float`, defaults to 0.02):
|
||||||
The final `beta` value.
|
The final `beta` value.
|
||||||
beta_schedule (`str`, defaults to `"linear"`):
|
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
|
||||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
|
||||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
|
||||||
trained_betas (`np.ndarray`, *optional*):
|
trained_betas (`np.ndarray`, *optional*):
|
||||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||||
solver_order (`int`, defaults to 2):
|
solver_order (`int`, defaults to 2):
|
||||||
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
||||||
sampling, and `solver_order=3` for unconditional sampling.
|
sampling, and `solver_order=3` for unconditional sampling.
|
||||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
|
||||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
|
||||||
`sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
|
directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
|
||||||
Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
|
Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
|
||||||
thresholding (`bool`, defaults to `False`):
|
thresholding (`bool`, defaults to `False`):
|
||||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||||
as Stable Diffusion.
|
as Stable Diffusion.
|
||||||
@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sample_max_value (`float`, defaults to 1.0):
|
sample_max_value (`float`, defaults to 1.0):
|
||||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
||||||
`algorithm_type="dpmsolver++"`.
|
`algorithm_type="dpmsolver++"`.
|
||||||
algorithm_type (`str`, defaults to `dpmsolver++`):
|
algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
|
||||||
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
|
||||||
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
[DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
|
||||||
paper, and the `dpmsolver++` type implements the algorithms in the
|
algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
|
||||||
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
`dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
||||||
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
|
||||||
solver_type (`str`, defaults to `midpoint`):
|
Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
|
||||||
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
for a small number of steps. It is recommended to use `midpoint` solvers.
|
||||||
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
|
||||||
lower_order_final (`bool`, defaults to `True`):
|
lower_order_final (`bool`, defaults to `True`):
|
||||||
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
||||||
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
||||||
@ -179,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
||||||
flow_shift (`float`, *optional*, defaults to 1.0):
|
flow_shift (`float`, *optional*, defaults to 1.0):
|
||||||
The shift value for the timestep schedule for flow matching.
|
The shift value for the timestep schedule for flow matching.
|
||||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
|
||||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
|
||||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||||
variance_type (`str`, *optional*):
|
variance_type (`"learned"` or `"learned_range"`, *optional*):
|
||||||
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
|
||||||
contains the predicted Gaussian variance.
|
output contains the predicted Gaussian variance.
|
||||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
|
||||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||||
steps_offset (`int`, defaults to 0):
|
steps_offset (`int`, defaults to 0):
|
||||||
@ -197,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||||
|
use_dynamic_shifting (`bool`, defaults to `False`):
|
||||||
|
Whether to use dynamic shifting for the timestep schedule.
|
||||||
|
time_shift_type (`"exponential"`, defaults to `"exponential"`):
|
||||||
|
The type of time shift to apply when using dynamic shifting.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||||
@ -208,15 +210,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
num_train_timesteps: int = 1000,
|
num_train_timesteps: int = 1000,
|
||||||
beta_start: float = 0.0001,
|
beta_start: float = 0.0001,
|
||||||
beta_end: float = 0.02,
|
beta_end: float = 0.02,
|
||||||
beta_schedule: str = "linear",
|
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||||
solver_order: int = 2,
|
solver_order: int = 2,
|
||||||
prediction_type: str = "epsilon",
|
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
|
||||||
thresholding: bool = False,
|
thresholding: bool = False,
|
||||||
dynamic_thresholding_ratio: float = 0.995,
|
dynamic_thresholding_ratio: float = 0.995,
|
||||||
sample_max_value: float = 1.0,
|
sample_max_value: float = 1.0,
|
||||||
algorithm_type: str = "dpmsolver++",
|
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
|
||||||
solver_type: str = "midpoint",
|
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||||
lower_order_final: bool = True,
|
lower_order_final: bool = True,
|
||||||
euler_at_final: bool = False,
|
euler_at_final: bool = False,
|
||||||
use_karras_sigmas: Optional[bool] = False,
|
use_karras_sigmas: Optional[bool] = False,
|
||||||
@ -225,14 +227,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
use_lu_lambdas: Optional[bool] = False,
|
use_lu_lambdas: Optional[bool] = False,
|
||||||
use_flow_sigmas: Optional[bool] = False,
|
use_flow_sigmas: Optional[bool] = False,
|
||||||
flow_shift: Optional[float] = 1.0,
|
flow_shift: Optional[float] = 1.0,
|
||||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
|
||||||
lambda_min_clipped: float = -float("inf"),
|
lambda_min_clipped: float = -float("inf"),
|
||||||
variance_type: Optional[str] = None,
|
variance_type: Optional[Literal["learned", "learned_range"]] = None,
|
||||||
timestep_spacing: str = "linspace",
|
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||||
steps_offset: int = 0,
|
steps_offset: int = 0,
|
||||||
rescale_betas_zero_snr: bool = False,
|
rescale_betas_zero_snr: bool = False,
|
||||||
use_dynamic_shifting: bool = False,
|
use_dynamic_shifting: bool = False,
|
||||||
time_shift_type: str = "exponential",
|
time_shift_type: Literal["exponential"] = "exponential",
|
||||||
):
|
):
|
||||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||||
@ -331,19 +333,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
def set_timesteps(
|
def set_timesteps(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int = None,
|
num_inference_steps: Optional[int] = None,
|
||||||
device: Union[str, torch.device] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
mu: Optional[float] = None,
|
mu: Optional[float] = None,
|
||||||
timesteps: Optional[List[int]] = None,
|
timesteps: Optional[List[int]] = None,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps (`int`):
|
num_inference_steps (`int`, *optional*):
|
||||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
device (`str` or `torch.device`, *optional*):
|
device (`str` or `torch.device`, *optional*):
|
||||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
mu (`float`, *optional*):
|
||||||
|
The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
|
||||||
|
`time_shift_type="exponential"`.
|
||||||
timesteps (`List[int]`, *optional*):
|
timesteps (`List[int]`, *optional*):
|
||||||
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
|
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
|
||||||
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
|
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
|
||||||
@ -503,7 +508,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||||
def _sigma_to_t(self, sigma, log_sigmas):
|
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Convert sigma values to corresponding timestep values through interpolation.
|
Convert sigma values to corresponding timestep values through interpolation.
|
||||||
|
|
||||||
@ -539,7 +544,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
t = t.reshape(sigma.shape)
|
t = t.reshape(sigma.shape)
|
||||||
return t
|
return t
|
||||||
|
|
||||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Convert sigma values to alpha_t and sigma_t values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sigma (`torch.Tensor`):
|
||||||
|
The sigma value(s) to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||||
|
A tuple containing (alpha_t, sigma_t) values.
|
||||||
|
"""
|
||||||
if self.config.use_flow_sigmas:
|
if self.config.use_flow_sigmas:
|
||||||
alpha_t = 1 - sigma
|
alpha_t = 1 - sigma
|
||||||
sigma_t = sigma
|
sigma_t = sigma
|
||||||
@ -588,8 +604,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||||
return sigmas
|
return sigmas
|
||||||
|
|
||||||
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||||
"""Constructs the noise schedule of Lu et al. (2022)."""
|
"""
|
||||||
|
Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
|
||||||
|
Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_lambdas (`torch.Tensor`):
|
||||||
|
The input lambda values to be converted.
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of inference steps to generate the noise schedule for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
The converted lambda values following the Lu noise schedule.
|
||||||
|
"""
|
||||||
|
|
||||||
lambda_min: float = in_lambdas[-1].item()
|
lambda_min: float = in_lambdas[-1].item()
|
||||||
lambda_max: float = in_lambdas[0].item()
|
lambda_max: float = in_lambdas[0].item()
|
||||||
@ -1069,7 +1098,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
)
|
)
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(
|
||||||
|
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Find the index for a given timestep in the schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The timestep for which to find the index.
|
||||||
|
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||||
|
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`int`:
|
||||||
|
The index of the timestep in the schedule.
|
||||||
|
"""
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
schedule_timesteps = self.timesteps
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
@ -1088,9 +1132,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
return step_index
|
return step_index
|
||||||
|
|
||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the step_index counter for the scheduler.
|
Initialize the step_index counter for the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The current timestep for which to initialize the step index.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
@ -1105,7 +1153,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
model_output: torch.Tensor,
|
model_output: torch.Tensor,
|
||||||
timestep: Union[int, torch.Tensor],
|
timestep: Union[int, torch.Tensor],
|
||||||
sample: torch.Tensor,
|
sample: torch.Tensor,
|
||||||
generator=None,
|
generator: Optional[torch.Generator] = None,
|
||||||
variance_noise: Optional[torch.Tensor] = None,
|
variance_noise: Optional[torch.Tensor] = None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[SchedulerOutput, Tuple]:
|
) -> Union[SchedulerOutput, Tuple]:
|
||||||
@ -1115,22 +1163,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_output (`torch.Tensor`):
|
model_output (`torch.Tensor`):
|
||||||
The direct output from learned diffusion model.
|
The direct output from the learned diffusion model.
|
||||||
timestep (`int`):
|
timestep (`int` or `torch.Tensor`):
|
||||||
The current discrete timestep in the diffusion chain.
|
The current discrete timestep in the diffusion chain.
|
||||||
sample (`torch.Tensor`):
|
sample (`torch.Tensor`):
|
||||||
A current instance of a sample created by the diffusion process.
|
A current instance of a sample created by the diffusion process.
|
||||||
generator (`torch.Generator`, *optional*):
|
generator (`torch.Generator`, *optional*):
|
||||||
A random number generator.
|
A random number generator.
|
||||||
variance_noise (`torch.Tensor`):
|
variance_noise (`torch.Tensor`, *optional*):
|
||||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||||
itself. Useful for methods such as [`LEdits++`].
|
itself. Useful for methods such as [`LEdits++`].
|
||||||
return_dict (`bool`):
|
return_dict (`bool`, defaults to `True`):
|
||||||
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||||
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||||
tuple is returned where the first element is the sample tensor.
|
tuple is returned where the first element is the sample tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -1210,6 +1258,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_samples (`torch.Tensor`):
|
||||||
|
The original samples without noise.
|
||||||
|
noise (`torch.Tensor`):
|
||||||
|
The noise to add to the samples.
|
||||||
|
timesteps (`torch.IntTensor`):
|
||||||
|
The timesteps at which to add noise to the samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
The noisy samples.
|
||||||
|
"""
|
||||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||||
|
|||||||
@ -413,6 +413,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||||
|
"""
|
||||||
|
Convert sigma values to alpha_t and sigma_t values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sigma (`torch.Tensor`):
|
||||||
|
The sigma value(s) to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||||
|
A tuple containing (alpha_t, sigma_t) values.
|
||||||
|
"""
|
||||||
if self.config.use_flow_sigmas:
|
if self.config.use_flow_sigmas:
|
||||||
alpha_t = 1 - sigma
|
alpha_t = 1 - sigma
|
||||||
sigma_t = sigma
|
sigma_t = sigma
|
||||||
|
|||||||
@ -491,6 +491,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||||
|
"""
|
||||||
|
Convert sigma values to alpha_t and sigma_t values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sigma (`torch.Tensor`):
|
||||||
|
The sigma value(s) to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||||
|
A tuple containing (alpha_t, sigma_t) values.
|
||||||
|
"""
|
||||||
if self.config.use_flow_sigmas:
|
if self.config.use_flow_sigmas:
|
||||||
alpha_t = 1 - sigma
|
alpha_t = 1 - sigma
|
||||||
sigma_t = sigma
|
sigma_t = sigma
|
||||||
@ -1079,7 +1090,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
raise ValueError(f"Order must be 1, 2, 3, got {order}")
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(
|
||||||
|
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Find the index for a given timestep in the schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The timestep for which to find the index.
|
||||||
|
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||||
|
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`int`:
|
||||||
|
The index of the timestep in the schedule.
|
||||||
|
"""
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
schedule_timesteps = self.timesteps
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
@ -1102,6 +1128,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep):
|
||||||
"""
|
"""
|
||||||
Initialize the step_index counter for the scheduler.
|
Initialize the step_index counter for the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The current timestep for which to initialize the step index.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
@ -1204,6 +1234,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_samples (`torch.Tensor`):
|
||||||
|
The original samples without noise.
|
||||||
|
noise (`torch.Tensor`):
|
||||||
|
The noise to add to the samples.
|
||||||
|
timesteps (`torch.IntTensor`):
|
||||||
|
The timesteps at which to add noise to the samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
The noisy samples.
|
||||||
|
"""
|
||||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||||
|
|||||||
@ -578,7 +578,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(
|
||||||
|
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Find the index for a given timestep in the schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The timestep for which to find the index.
|
||||||
|
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||||
|
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`int`:
|
||||||
|
The index of the timestep in the schedule.
|
||||||
|
"""
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
schedule_timesteps = self.timesteps
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
@ -601,6 +616,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep):
|
||||||
"""
|
"""
|
||||||
Initialize the step_index counter for the scheduler.
|
Initialize the step_index counter for the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The current timestep for which to initialize the step index.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
|
|||||||
@ -423,6 +423,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||||
|
"""
|
||||||
|
Convert sigma values to alpha_t and sigma_t values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sigma (`torch.Tensor`):
|
||||||
|
The sigma value(s) to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||||
|
A tuple containing (alpha_t, sigma_t) values.
|
||||||
|
"""
|
||||||
if self.config.use_flow_sigmas:
|
if self.config.use_flow_sigmas:
|
||||||
alpha_t = 1 - sigma
|
alpha_t = 1 - sigma
|
||||||
sigma_t = sigma
|
sigma_t = sigma
|
||||||
@ -1103,7 +1114,22 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(
|
||||||
|
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Find the index for a given timestep in the schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The timestep for which to find the index.
|
||||||
|
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||||
|
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`int`:
|
||||||
|
The index of the timestep in the schedule.
|
||||||
|
"""
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
schedule_timesteps = self.timesteps
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
@ -1126,6 +1152,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep):
|
||||||
"""
|
"""
|
||||||
Initialize the step_index counter for the scheduler.
|
Initialize the step_index counter for the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The current timestep for which to initialize the step index.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
|
|||||||
@ -513,6 +513,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||||
|
"""
|
||||||
|
Convert sigma values to alpha_t and sigma_t values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sigma (`torch.Tensor`):
|
||||||
|
The sigma value(s) to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||||
|
A tuple containing (alpha_t, sigma_t) values.
|
||||||
|
"""
|
||||||
if self.config.use_flow_sigmas:
|
if self.config.use_flow_sigmas:
|
||||||
alpha_t = 1 - sigma
|
alpha_t = 1 - sigma
|
||||||
sigma_t = sigma
|
sigma_t = sigma
|
||||||
@ -984,7 +995,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
def index_for_timestep(
|
||||||
|
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Find the index for a given timestep in the schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The timestep for which to find the index.
|
||||||
|
schedule_timesteps (`torch.Tensor`, *optional*):
|
||||||
|
The timestep schedule to search in. If `None`, uses `self.timesteps`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`int`:
|
||||||
|
The index of the timestep in the schedule.
|
||||||
|
"""
|
||||||
if schedule_timesteps is None:
|
if schedule_timesteps is None:
|
||||||
schedule_timesteps = self.timesteps
|
schedule_timesteps = self.timesteps
|
||||||
|
|
||||||
@ -1007,6 +1033,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep):
|
||||||
"""
|
"""
|
||||||
Initialize the step_index counter for the scheduler.
|
Initialize the step_index counter for the scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (`int` or `torch.Tensor`):
|
||||||
|
The current timestep for which to initialize the step index.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
@ -1119,6 +1149,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Add noise to the original samples according to the noise schedule at the specified timesteps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_samples (`torch.Tensor`):
|
||||||
|
The original samples without noise.
|
||||||
|
noise (`torch.Tensor`):
|
||||||
|
The noise to add to the samples.
|
||||||
|
timesteps (`torch.IntTensor`):
|
||||||
|
The timesteps at which to add noise to the samples.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
The noisy samples.
|
||||||
|
"""
|
||||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user