Improve docstrings and type hints in scheduling_amused.py (#12623)
* Improve docstrings and type hints in scheduling_amused.py
  - Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk)
  - Enhance AmusedSchedulerOutput with proper Optional typing
  - Add comprehensive docstrings for AmusedScheduler class
  - Improve __init__, set_timesteps, step, and add_noise methods
  - Fix type hints to match documentation conventions
  - All changes follow project standards from issue #9567
* Enhance type hints and docstrings in scheduling_amused.py
  - Update type hints for `prev_sample` and `pred_original_sample` in `AmusedSchedulerOutput` to reflect their tensor types.
  - Improve docstring for `gumbel_noise` to specify the output tensor's dtype and device.
  - Refine `AmusedScheduler` class documentation, including detailed descriptions of the masking schedule and temperature parameters.
  - Adjust type hints in `set_timesteps` and `step` methods for better clarity and consistency.
* Apply review feedback on scheduling_amused.py
  - Replace generic [Amused] reference with specific [`AmusedPipeline`] reference for consistency with project documentation conventions
This commit is contained in:
parent
d6c63bb956
commit
44c3101685
@ -1,6 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -9,13 +9,48 @@ from ..utils import BaseOutput
|
|||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
def gumbel_noise(t, generator=None):
|
def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Generate Gumbel noise for sampling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t (`torch.Tensor`):
|
||||||
|
Input tensor whose shape, dtype, and device the output noise matches.
|
||||||
|
generator (`torch.Generator`, *optional*):
|
||||||
|
A random number generator for reproducible sampling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
Gumbel-distributed noise with the same shape, dtype, and device as the input tensor.
|
||||||
|
"""
|
||||||
device = generator.device if generator is not None else t.device
|
device = generator.device if generator is not None else t.device
|
||||||
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
|
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
|
||||||
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
|
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
|
||||||
|
|
||||||
|
|
||||||
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
|
def mask_by_random_topk(
|
||||||
|
mask_len: torch.Tensor,
|
||||||
|
probs: torch.Tensor,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask_len (`torch.Tensor`):
|
||||||
|
Number of tokens to mask per sample in the batch.
|
||||||
|
probs (`torch.Tensor`):
|
||||||
|
Probability scores for each token.
|
||||||
|
temperature (`float`, *optional*, defaults to 1.0):
|
||||||
|
Temperature parameter for controlling randomness in the masking process.
|
||||||
|
generator (`torch.Generator`, *optional*):
|
||||||
|
A random number generator for reproducible sampling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.Tensor`:
|
||||||
|
Boolean mask indicating which tokens should be masked.
|
||||||
|
"""
|
||||||
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
|
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
|
||||||
sorted_confidence = torch.sort(confidence, dim=-1).values
|
sorted_confidence = torch.sort(confidence, dim=-1).values
|
||||||
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
|
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
|
||||||
@ -29,28 +64,46 @@ class AmusedSchedulerOutput(BaseOutput):
|
|||||||
Output class for the scheduler's `step` function output.
|
Output class for the scheduler's `step` function output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
|
||||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
Computed sample `(x_{t-1})` of the previous timestep, as token IDs. `prev_sample` should be used as next model
|
||||||
denoising loop.
|
input in the denoising loop.
|
||||||
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
pred_original_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`, *optional*):
|
||||||
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
The predicted fully denoised sample `(x_{0})`, as token IDs, based on the model output from the current
|
||||||
`pred_original_sample` can be used to preview progress or for guidance.
|
timestep. `pred_original_sample` can be used to preview progress or for guidance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prev_sample: torch.Tensor
|
prev_sample: torch.Tensor
|
||||||
pred_original_sample: torch.Tensor = None
|
pred_original_sample: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
A scheduler for masked token generation as used in [`AmusedPipeline`].
|
||||||
|
|
||||||
|
This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear
|
||||||
|
schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates
|
||||||
|
on discrete token IDs, making it suitable for autoregressive and non-autoregressive masked token generation models.
|
||||||
|
|
||||||
|
This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
|
||||||
|
generic methods the library implements for all schedulers such as loading and saving.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask_token_id (`int`):
|
||||||
|
The token ID used to represent masked tokens in the sequence.
|
||||||
|
masking_schedule (`Literal["cosine", "linear"]`, *optional*, defaults to `"cosine"`):
|
||||||
|
The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or `"linear"`.
|
||||||
|
"""
|
||||||
|
|
||||||
order = 1
|
order = 1
|
||||||
|
|
||||||
temperatures: torch.Tensor
|
temperatures: Optional[torch.Tensor]
|
||||||
|
timesteps: Optional[torch.Tensor]
|
||||||
|
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mask_token_id: int,
|
mask_token_id: int,
|
||||||
masking_schedule: str = "cosine",
|
masking_schedule: Literal["cosine", "linear"] = "cosine",
|
||||||
):
|
):
|
||||||
self.temperatures = None
|
self.temperatures = None
|
||||||
self.timesteps = None
|
self.timesteps = None
|
||||||
@ -58,9 +111,23 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def set_timesteps(
|
def set_timesteps(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
|
||||||
device: Union[str, torch.device] = None,
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
):
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
|
temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to `(2, 0)`):
|
||||||
|
Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
|
||||||
|
temperatures will be linearly interpolated between the first and second values across all timesteps. If
|
||||||
|
a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps and temperatures should be moved. If `None`, the timesteps are not
|
||||||
|
moved.
|
||||||
|
"""
|
||||||
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
|
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
|
||||||
|
|
||||||
if isinstance(temperature, (tuple, list)):
|
if isinstance(temperature, (tuple, list)):
|
||||||
@ -71,12 +138,38 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
def step(
|
def step(
|
||||||
self,
|
self,
|
||||||
model_output: torch.Tensor,
|
model_output: torch.Tensor,
|
||||||
timestep: torch.long,
|
timestep: int,
|
||||||
sample: torch.LongTensor,
|
sample: torch.LongTensor,
|
||||||
starting_mask_ratio: int = 1,
|
starting_mask_ratio: float = 1.0,
|
||||||
generator: Optional[torch.Generator] = None,
|
generator: Optional[torch.Generator] = None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[AmusedSchedulerOutput, Tuple]:
|
) -> Union[AmusedSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Predict the sample at the previous timestep by masking tokens based on confidence scores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_output (`torch.Tensor`):
|
||||||
|
The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
|
||||||
|
codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
|
||||||
|
timestep (`int`):
|
||||||
|
The current discrete timestep in the diffusion chain.
|
||||||
|
sample (`torch.LongTensor`):
|
||||||
|
A current instance of a sample created by the diffusion process. Contains token IDs, with masked
|
||||||
|
positions indicated by `mask_token_id`.
|
||||||
|
starting_mask_ratio (`float`, *optional*, defaults to 1.0):
|
||||||
|
A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
|
||||||
|
masked at each step.
|
||||||
|
generator (`torch.Generator`, *optional*):
|
||||||
|
A random number generator for reproducible sampling.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
|
||||||
|
If `return_dict` is `True`, [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] is returned,
|
||||||
|
otherwise a tuple is returned where the first element is the sample tensor (`prev_sample`) and the
|
||||||
|
second element is the predicted original sample tensor (`pred_original_sample`).
|
||||||
|
"""
|
||||||
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
|
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
|
||||||
|
|
||||||
if two_dim_input:
|
if two_dim_input:
|
||||||
@ -137,7 +230,27 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
return AmusedSchedulerOutput(prev_sample, pred_original_sample)
|
return AmusedSchedulerOutput(prev_sample, pred_original_sample)
|
||||||
|
|
||||||
def add_noise(self, sample, timesteps, generator=None):
|
def add_noise(
|
||||||
|
self,
|
||||||
|
sample: torch.LongTensor,
|
||||||
|
timesteps: int,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
|
) -> torch.LongTensor:
|
||||||
|
"""
|
||||||
|
Add noise to a sample by randomly masking tokens according to the masking schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.LongTensor`):
|
||||||
|
The input sample containing token IDs to be partially masked.
|
||||||
|
timesteps (`int`):
|
||||||
|
The timestep that determines how much masking to apply. Higher timesteps result in more masking.
|
||||||
|
generator (`torch.Generator`, *optional*):
|
||||||
|
A random number generator for reproducible masking.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.LongTensor`:
|
||||||
|
The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
|
||||||
|
"""
|
||||||
step_idx = (self.timesteps == timesteps).nonzero()
|
step_idx = (self.timesteps == timesteps).nonzero()
|
||||||
ratio = (step_idx + 1) / len(self.timesteps)
|
ratio = (step_idx + 1) / len(self.timesteps)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user