Improve docstrings and type hints in scheduling_amused.py (#12623)

* Improve docstrings and type hints in scheduling_amused.py

- Add complete type hints for helper functions (gumbel_noise, mask_by_random_topk)
- Enhance AmusedSchedulerOutput with proper Optional typing
- Add comprehensive docstrings for AmusedScheduler class
- Improve __init__, set_timesteps, step, and add_noise methods
- Fix type hints to match documentation conventions
- All changes follow project standards from issue #9567

* Enhance type hints and docstrings in scheduling_amused.py

- Update type hints for `prev_sample` and `pred_original_sample` in `AmusedSchedulerOutput` to reflect their tensor types.
- Improve docstring for `gumbel_noise` to specify the output tensor's dtype and device.
- Refine `AmusedScheduler` class documentation, including detailed descriptions of the masking schedule and temperature parameters.
- Adjust type hints in `set_timesteps` and `step` methods for better clarity and consistency.

* Apply review feedback on scheduling_amused.py

- Replace generic [Amused] reference with specific [`AmusedPipeline`] reference for consistency with project documentation conventions
This commit is contained in:
David El Malih 2025-11-13 02:26:10 +01:00 committed by GitHub
parent d6c63bb956
commit 44c3101685
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import torch
@ -9,13 +9,48 @@ from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
def gumbel_noise(t, generator=None):
def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
"""
Generate Gumbel noise for sampling.
Args:
t (`torch.Tensor`):
Input tensor to match the shape and dtype of the output noise.
generator (`torch.Generator`, *optional*):
A random number generator for reproducible sampling.
Returns:
`torch.Tensor`:
Gumbel-distributed noise with the same shape, dtype, and device as the input tensor.
"""
device = generator.device if generator is not None else t.device
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
def mask_by_random_topk(
mask_len: torch.Tensor,
probs: torch.Tensor,
temperature: float = 1.0,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
"""
Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.
Args:
mask_len (`torch.Tensor`):
Number of tokens to mask per sample in the batch.
probs (`torch.Tensor`):
Probability scores for each token.
temperature (`float`, *optional*, defaults to 1.0):
Temperature parameter for controlling randomness in the masking process.
generator (`torch.Generator`, *optional*):
A random number generator for reproducible sampling.
Returns:
`torch.Tensor`:
Boolean mask indicating which tokens should be masked.
"""
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
sorted_confidence = torch.sort(confidence, dim=-1).values
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
@ -29,28 +64,46 @@ class AmusedSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
Computed sample `(x_{t-1})` of previous timestep with token IDs. `prev_sample` should be used as next model
input in the denoising loop.
pred_original_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`, *optional*):
The predicted fully denoised sample `(x_{0})` with token IDs based on the model output from the current
timestep. `pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.Tensor
pred_original_sample: torch.Tensor = None
pred_original_sample: Optional[torch.Tensor] = None
class AmusedScheduler(SchedulerMixin, ConfigMixin):
"""
A scheduler for masked token generation as used in [`AmusedPipeline`].
This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear
schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates
on discrete token IDs, making it suitable for autoregressive and non-autoregressive masked token generation models.
This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
generic methods the library implements for all schedulers such as loading and saving.
Args:
mask_token_id (`int`):
The token ID used to represent masked tokens in the sequence.
masking_schedule (`Literal["cosine", "linear"]`, *optional*, defaults to `"cosine"`):
The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or `"linear"`.
"""
order = 1
temperatures: torch.Tensor
temperatures: Optional[torch.Tensor]
timesteps: Optional[torch.Tensor]
@register_to_config
def __init__(
self,
mask_token_id: int,
masking_schedule: str = "cosine",
masking_schedule: Literal["cosine", "linear"] = "cosine",
):
self.temperatures = None
self.timesteps = None
@ -58,9 +111,23 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int,
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
device: Union[str, torch.device] = None,
):
temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Set the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to `(2, 0)`):
Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
temperatures will be linearly interpolated between the first and second values across all timesteps. If
a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps and temperatures should be moved to. If `None`, the timesteps are not
moved.
"""
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
if isinstance(temperature, (tuple, list)):
@ -71,12 +138,38 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: torch.long,
timestep: int,
sample: torch.LongTensor,
starting_mask_ratio: int = 1,
starting_mask_ratio: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[AmusedSchedulerOutput, Tuple]:
) -> Union[AmusedSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
"""
Predict the sample at the previous timestep by masking tokens based on confidence scores.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.LongTensor`):
A current instance of a sample created by the diffusion process. Contains token IDs, with masked
positions indicated by `mask_token_id`.
starting_mask_ratio (`float`, *optional*, defaults to 1.0):
A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
masked at each step.
generator (`torch.Generator`, *optional*):
A random number generator for reproducible sampling.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.
Returns:
[`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor (`prev_sample`) and the
second element is the predicted original sample tensor (`pred_original_sample`).
"""
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
if two_dim_input:
@ -137,7 +230,27 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
return AmusedSchedulerOutput(prev_sample, pred_original_sample)
def add_noise(self, sample, timesteps, generator=None):
def add_noise(
self,
sample: torch.LongTensor,
timesteps: int,
generator: Optional[torch.Generator] = None,
) -> torch.LongTensor:
"""
Add noise to a sample by randomly masking tokens according to the masking schedule.
Args:
sample (`torch.LongTensor`):
The input sample containing token IDs to be partially masked.
timesteps (`int`):
The timestep that determines how much masking to apply. Higher timesteps result in more masking.
generator (`torch.Generator`, *optional*):
A random number generator for reproducible masking.
Returns:
`torch.LongTensor`:
The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
"""
step_idx = (self.timesteps == timesteps).nonzero()
ratio = (step_idx + 1) / len(self.timesteps)