Improve docstrings and type hints in scheduling_amused.py (#12623)
* Improve docstrings and type hints in scheduling_amused.py - Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk) - Enhance AmusedSchedulerOutput with proper Optional typing - Add comprehensive docstrings for AmusedScheduler class - Improve __init__, set_timesteps, step, and add_noise methods - Fix type hints to match documentation conventions - All changes follow project standards from issue #9567 * Enhance type hints and docstrings in scheduling_amused.py - Update type hints for `prev_sample` and `pred_original_sample` in `AmusedSchedulerOutput` to reflect their tensor types. - Improve docstring for `gumbel_noise` to specify the output tensor's dtype and device. - Refine `AmusedScheduler` class documentation, including detailed descriptions of the masking schedule and temperature parameters. - Adjust type hints in `set_timesteps` and `step` methods for better clarity and consistency. * Apply review feedback on scheduling_amused.py - Replace generic [Amused] reference with specific [`AmusedPipeline`] reference for consistency with project documentation conventions
This commit is contained in:
parent
d6c63bb956
commit
44c3101685
@ -1,6 +1,6 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -9,13 +9,48 @@ from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
def gumbel_noise(t, generator=None):
|
||||
def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
|
||||
"""
|
||||
Generate Gumbel noise for sampling.
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor`):
|
||||
Input tensor to match the shape and dtype of the output noise.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator for reproducible sampling.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Gumbel-distributed noise with the same shape, dtype, and device as the input tensor.
|
||||
"""
|
||||
device = generator.device if generator is not None else t.device
|
||||
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
|
||||
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
|
||||
|
||||
|
||||
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
|
||||
def mask_by_random_topk(
|
||||
mask_len: torch.Tensor,
|
||||
probs: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.
|
||||
|
||||
Args:
|
||||
mask_len (`torch.Tensor`):
|
||||
Number of tokens to mask per sample in the batch.
|
||||
probs (`torch.Tensor`):
|
||||
Probability scores for each token.
|
||||
temperature (`float`, *optional*, defaults to 1.0):
|
||||
Temperature parameter for controlling randomness in the masking process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator for reproducible sampling.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Boolean mask indicating which tokens should be masked.
|
||||
"""
|
||||
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
|
||||
sorted_confidence = torch.sort(confidence, dim=-1).values
|
||||
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
|
||||
@ -29,28 +64,46 @@ class AmusedSchedulerOutput(BaseOutput):
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
|
||||
Computed sample `(x_{t-1})` of previous timestep with token IDs. `prev_sample` should be used as next model
|
||||
input in the denoising loop.
|
||||
pred_original_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`, *optional*):
|
||||
The predicted fully denoised sample `(x_{0})` with token IDs based on the model output from the current
|
||||
timestep. `pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.Tensor
|
||||
pred_original_sample: torch.Tensor = None
|
||||
pred_original_sample: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
A scheduler for masked token generation as used in [`AmusedPipeline`].
|
||||
|
||||
This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear
|
||||
schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates
|
||||
on discrete token IDs, making it suitable for autoregressive and non-autoregressive masked token generation models.
|
||||
|
||||
This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
|
||||
generic methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
mask_token_id (`int`):
|
||||
The token ID used to represent masked tokens in the sequence.
|
||||
masking_schedule (`Literal["cosine", "linear"]`, *optional*, defaults to `"cosine"`):
|
||||
The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or `"linear"`.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
temperatures: torch.Tensor
|
||||
temperatures: Optional[torch.Tensor]
|
||||
timesteps: Optional[torch.Tensor]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
mask_token_id: int,
|
||||
masking_schedule: str = "cosine",
|
||||
masking_schedule: Literal["cosine", "linear"] = "cosine",
|
||||
):
|
||||
self.temperatures = None
|
||||
self.timesteps = None
|
||||
@ -58,9 +111,23 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
||||
device: Union[str, torch.device] = None,
|
||||
):
|
||||
temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to `(2, 0)`):
|
||||
Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
|
||||
temperatures will be linearly interpolated between the first and second values across all timesteps. If
|
||||
a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps and temperatures should be moved to. If `None`, the timesteps are not
|
||||
moved.
|
||||
"""
|
||||
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
|
||||
|
||||
if isinstance(temperature, (tuple, list)):
|
||||
@ -71,12 +138,38 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: torch.long,
|
||||
timestep: int,
|
||||
sample: torch.LongTensor,
|
||||
starting_mask_ratio: int = 1,
|
||||
starting_mask_ratio: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[AmusedSchedulerOutput, Tuple]:
|
||||
) -> Union[AmusedSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by masking tokens based on confidence scores.
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
|
||||
codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.LongTensor`):
|
||||
A current instance of a sample created by the diffusion process. Contains token IDs, with masked
|
||||
positions indicated by `mask_token_id`.
|
||||
starting_mask_ratio (`float`, *optional*, defaults to 1.0):
|
||||
A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
|
||||
masked at each step.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator for reproducible sampling.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] is returned,
|
||||
otherwise a tuple is returned where the first element is the sample tensor (`prev_sample`) and the
|
||||
second element is the predicted original sample tensor (`pred_original_sample`).
|
||||
"""
|
||||
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
|
||||
|
||||
if two_dim_input:
|
||||
@ -137,7 +230,27 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return AmusedSchedulerOutput(prev_sample, pred_original_sample)
|
||||
|
||||
def add_noise(self, sample, timesteps, generator=None):
|
||||
def add_noise(
|
||||
self,
|
||||
sample: torch.LongTensor,
|
||||
timesteps: int,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Add noise to a sample by randomly masking tokens according to the masking schedule.
|
||||
|
||||
Args:
|
||||
sample (`torch.LongTensor`):
|
||||
The input sample containing token IDs to be partially masked.
|
||||
timesteps (`int`):
|
||||
The timestep that determines how much masking to apply. Higher timesteps result in more masking.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator for reproducible masking.
|
||||
|
||||
Returns:
|
||||
`torch.LongTensor`:
|
||||
The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
|
||||
"""
|
||||
step_idx = (self.timesteps == timesteps).nonzero()
|
||||
ratio = (step_idx + 1) / len(self.timesteps)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user