[BUG] Fix FP64 Gumbel precision coverage (#43150)

Signed-off-by: tianyu-z <zhangtianyupro@gmail.com>
Signed-off-by: Tianyu Zhang <53099276+tianyu-z@users.noreply.github.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
This commit is contained in:
Tianyu Zhang
2026-06-05 04:04:14 -07:00
committed by GitHub
parent 8a83e6f2d7
commit 7fe7800fa4
11 changed files with 391 additions and 21 deletions
+64 -1
View File
@@ -544,13 +544,15 @@ def native_sample_recovered_tokens(
target_probs: torch.Tensor, # [num_tokens, vocab_size]
sampling_metadata: SamplingMetadata,
device: torch.device,
use_fp64_gumbel: bool = False,
) -> torch.Tensor:
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
dtype=q_dtype,
device=device,
)
q.exponential_()
@@ -935,6 +937,67 @@ def test_sample_recovered_tokens(
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
def test_sample_recovered_tokens_uses_fp64_exponential_race_when_requested():
    """Compare the recovered-token sampler with the native reference in FP64.

    Both implementations receive the same inputs, the same per-request
    seeded generators and ``use_fp64_gumbel=True``; the recovered token ids
    are required to match exactly.
    """
    batch_size = 2
    vocab_size = 64
    max_spec_len = 2
    num_tokens = batch_size * max_spec_len

    def _random_distribution() -> torch.Tensor:
        # Uniform random scores, normalised into a distribution per token.
        scores = torch.rand(
            num_tokens,
            vocab_size,
            dtype=torch.float32,
            device=DEVICE_TYPE,
        )
        return F.softmax(scores, dim=-1)

    # Draft first, then target, so the default RNG is consumed in a fixed order.
    draft_probs = _random_distribution()
    target_probs = _random_distribution()
    draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)

    # One explicitly seeded generator per request.
    generators = {}
    for req_idx in range(batch_size):
        generators[req_idx] = torch.Generator(device=DEVICE_TYPE).manual_seed(
            req_idx
        )

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
        generators=generators,
    )
    spec_decode_metadata = create_spec_decode_metadata(
        draft_token_ids.reshape(batch_size, max_spec_len).tolist(),
        target_probs.log(),
    )

    # Identical arguments are fed to the reference and to the real sampler.
    shared_args = (
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
    )
    shared_kwargs = {
        "device": torch.device(DEVICE_TYPE),
        "use_fp64_gumbel": True,
    }

    expected = native_sample_recovered_tokens(*shared_args, **shared_kwargs)
    actual = sample_recovered_tokens(*shared_args, **shared_kwargs)

    assert torch.equal(actual, expected)
########################### Tests for Synthetic Rejection Sampling #########
+39 -1
View File
@@ -6,7 +6,12 @@ from torch import Generator
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.sample.ops.topk_topp_sampler import (
apply_top_k_top_p_pytorch,
random_sample,
)
from vllm.v1.sample.sampler import Sampler
DEVICE_TYPE = current_platform.device_type
@@ -38,6 +43,10 @@ def _flashinfer_topk_topp_supported() -> bool:
FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()
def _seed_default_generator(seed: int) -> None:
    """Seed the default random state by delegating to ``set_random_seed``."""
    set_random_seed(seed)
@pytest.fixture(autouse=True)
def reset_default_device():
"""
@@ -49,6 +58,35 @@ def reset_default_device():
torch.set_default_device(original_device)
def test_sampler_threads_fp64_gumbel_to_topk_topp_sampler():
    """A ``Sampler`` built with FP64 noise must pass the flag to its inner sampler."""
    inner_sampler = Sampler(use_fp64_gumbel=True).topk_topp_sampler
    assert inner_sampler.use_fp64_gumbel
def test_random_sample_uses_fp64_exponential_race_when_requested():
    """``random_sample`` with FP64 noise must match a hand-rolled FP64 race.

    The reference draws float64 exponential noise from the freshly seeded
    default generator and takes ``argmax(probs / q)``; ``random_sample`` is
    then run after re-seeding with the same value and must agree exactly.
    """
    torch.set_default_device(DEVICE_TYPE)
    seed = 12345
    probs = torch.tensor(
        [
            [0.70, 0.20, 0.10],
            [0.05, 0.15, 0.80],
            [0.25, 0.25, 0.50],
        ],
        dtype=torch.float32,
        device=DEVICE_TYPE,
    )

    # Reference: float64 noise, scores computed as (1 / q) * probs.
    _seed_default_generator(seed)
    noise = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    noise.exponential_()
    reference_scores = noise.reciprocal_().mul_(probs)
    expected = reference_scores.argmax(dim=-1).view(-1)

    # Function under test, starting from the same seed; it gets its own
    # copy of probs and no per-request generators.
    _seed_default_generator(seed)
    actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True)

    assert torch.equal(actual, expected)
def test_topk_impl_equivalence():
torch.set_default_device(DEVICE_TYPE)
generator = Generator(device=DEVICE_TYPE).manual_seed(33)
+2 -1
View File
@@ -1034,7 +1034,8 @@ def test_propose_stores_probabilistic_draft_probs(monkeypatch):
proposer.model = model_mock
proposer._draft_attn_layer_names = {"layer.0"}
def fake_compute_probs(logits, sampling_metadata):
def fake_compute_probs(logits, sampling_metadata, use_fp64_gumbel):
assert use_fp64_gumbel == proposer.use_fp64_gumbel
probs = torch.softmax(logits, dim=-1)
return probs.argmax(dim=-1), probs
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.llm_base_proposer import (
compute_probs_and_sample_next_token,
)
DEVICE_TYPE = current_platform.device_type
def _seed_default_generator(seed: int) -> None:
    """Seed the default random state by delegating to ``set_random_seed``."""
    set_random_seed(seed)
def _make_sampling_metadata(batch_size: int) -> SamplingMetadata:
    """Build all-random ``SamplingMetadata`` for ``batch_size`` requests.

    Temperature is 1.0 for every request; top-k/top-p, per-request
    generators, penalties, bad words and the allowed-token mask are all
    left unset or empty.
    """

    def _empty_penalty() -> torch.Tensor:
        # Each penalty field gets its own zero-length tensor.
        return torch.empty(0, device=DEVICE_TYPE)

    def _empty_token_lists() -> list[list[int]]:
        return [[] for _ in range(batch_size)]

    unit_temperature = torch.ones(
        batch_size, dtype=torch.float32, device=DEVICE_TYPE
    )
    return SamplingMetadata(
        temperature=unit_temperature,
        all_greedy=False,
        all_random=True,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=None,
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=_empty_penalty(),
        presence_penalties=_empty_penalty(),
        repetition_penalties=_empty_penalty(),
        output_token_ids=_empty_token_lists(),
        spec_token_ids=_empty_token_lists(),
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessors(),
    )
def test_compute_probs_and_sample_next_token_uses_fp64_exponential_race():
    """Draft sampling with FP64 noise must match a hand-rolled FP64 race.

    The reference applies a float32 softmax to the logits, draws float64
    exponential noise from the freshly seeded default generator and takes
    ``argmax(probs / q)``. ``compute_probs_and_sample_next_token`` is run
    from the same seed and must return the same ids and matching probs.
    """
    batch_size = 4
    vocab_size = 32
    seed = 12345

    # The logits come from a dedicated generator, not the default one.
    logits_generator = torch.Generator(device=DEVICE_TYPE).manual_seed(11)
    logits = torch.randn(
        batch_size,
        vocab_size,
        dtype=torch.float32,
        device=DEVICE_TYPE,
        generator=logits_generator,
    )
    metadata = _make_sampling_metadata(batch_size)

    # Reference: float64 noise, scores computed as (1 / q) * probs.
    _seed_default_generator(seed)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    noise = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    noise.exponential_()
    reference_scores = noise.reciprocal_().mul_(probs)
    expected_ids = reference_scores.argmax(dim=-1).view(-1)

    # Function under test, starting from the same seed.
    _seed_default_generator(seed)
    actual_ids, actual_probs = compute_probs_and_sample_next_token(
        logits.clone(),
        metadata,
        use_fp64_gumbel=True,
    )

    assert torch.equal(actual_ids, expected_ids)
    assert torch.allclose(actual_probs, probs)
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CUDA proof for fp32 exponential-race tail truncation.
This script is intentionally not a unit test. It is a reproducible, GPU-only
statistical proof for the hidden Gumbel-max idiom:
q.exponential_()
sample = (probs / q).argmax()
For q ~ Exp(1), this is equivalent to argmax(log(probs) + Gumbel). On CUDA,
fp32 exponential samples inherit a 24-bit uniform lower-tail cutoff, so very
small q values are impossible. The many-tail experiment below chooses a case
where a correct sampler should select a low-probability tail token dozens of
times, while fp32 q cannot select one.
"""
from __future__ import annotations
import argparse
import math
import time
import torch
def _seed(seed: int) -> None:
torch.manual_seed(seed)
def measure_exponential_lower_tail(
    *,
    device: torch.device,
    samples: int,
    chunk_size: int,
    seed: int,
) -> None:
    """Report how often ``exponential_()`` lands below ``2**-24`` per dtype.

    For float32 and then float64, the RNG is re-seeded with ``seed`` and
    ``samples`` exponential draws are generated on ``device`` in chunks of
    at most ``chunk_size``. The number of draws below the threshold, the
    smallest and largest draw, and the elapsed time are printed.

    Args:
        device: Device on which the noise tensors are allocated.
        samples: Total number of draws per dtype.
        chunk_size: Maximum number of draws generated per iteration.
        seed: Seed applied before each dtype's run.

    Raises:
        ValueError: If there is work to do (``samples > 0``) but
            ``chunk_size`` is not positive.
    """
    # A zero chunk size would never reduce ``remaining`` (endless loop) and a
    # negative one fails deep inside tensor allocation; reject both up front.
    if samples > 0 and chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    threshold = 2.0**-24
    print(f"lower-tail threshold: {threshold:.18e}")

    for dtype in (torch.float32, torch.float64):
        # Same seed for both dtypes so the runs start from the same RNG state.
        _seed(seed)
        count_below = 0
        min_q = float("inf")
        max_q = 0.0
        start = time.perf_counter()
        remaining = samples
        while remaining > 0:
            n = min(chunk_size, remaining)
            q = torch.empty((n,), dtype=dtype, device=device)
            q.exponential_()
            count_below += int((q < threshold).sum().item())
            min_q = min(min_q, float(q.min().item()))
            max_q = max(max_q, float(q.max().item()))
            remaining -= n
        # Make sure all device work is finished before stopping the clock.
        torch.accelerator.synchronize()
        elapsed = time.perf_counter() - start
        print(
            f"{dtype}: samples={samples} count(q < 2^-24)={count_below} "
            f"min={min_q:.18e} max={max_q:.6f} elapsed={elapsed:.2f}s"
        )
def run_many_tail_race(
    *,
    device: torch.device,
    trials: int,
    num_tail_tokens: int,
    gap: float,
    chunk_trials: int,
    seed: int,
) -> None:
    """Count how often a tail token wins the exponential race, per dtype.

    Each trial pits one head token with weight 1.0 against
    ``num_tail_tokens`` tail tokens with weight ``exp(-gap)`` each. Scores
    are ``weight / q`` with ``q`` drawn by ``exponential_()``. A trial is a
    hit when the best tail score exceeds the head score. The expected hit
    count and the observed hits for float32 and float64 are printed.

    Args:
        device: Device on which the noise tensors are allocated.
        trials: Total number of races per dtype.
        num_tail_tokens: Number of tail tokens competing in each race.
        gap: Log-weight gap between the head token and each tail token.
        chunk_trials: Maximum number of races evaluated per iteration.
        seed: Seed applied before each dtype's run.

    Raises:
        ValueError: If there is work to do (``trials > 0``) but
            ``chunk_trials`` is not positive.
    """
    # A zero chunk would never reduce ``remaining`` (endless loop) and a
    # negative one fails deep inside tensor allocation; reject both up front.
    if trials > 0 and chunk_trials <= 0:
        raise ValueError(f"chunk_trials must be positive, got {chunk_trials}")

    p_tail = math.exp(-gap)
    # Probability that any tail token wins is (N * p) / (1 + N * p).
    expected_tail_hits = (
        trials * (num_tail_tokens * p_tail) / (1.0 + num_tail_tokens * p_tail)
    )
    print(
        "many-tail race: "
        f"trials={trials} num_tail_tokens={num_tail_tokens} gap={gap} "
        f"expected_tail_hits={expected_tail_hits:.4f}"
    )

    for dtype in (torch.float32, torch.float64):
        # Same seed for both dtypes so the runs start from the same RNG state.
        _seed(seed)
        hits = 0
        p0 = torch.tensor(1.0, dtype=dtype, device=device)
        pt = torch.tensor(p_tail, dtype=dtype, device=device)
        start = time.perf_counter()
        remaining = trials
        while remaining > 0:
            batch = min(chunk_trials, remaining)
            q0 = torch.empty((batch,), dtype=dtype, device=device)
            q0.exponential_()
            qt = torch.empty((batch, num_tail_tokens), dtype=dtype, device=device)
            qt.exponential_()
            head_score = p0 / q0
            # Only the strongest tail token can beat the head.
            tail_score = (pt / qt).amax(dim=-1)
            hits += int((tail_score > head_score).sum().item())
            remaining -= batch
        # Make sure all device work is finished before stopping the clock.
        torch.accelerator.synchronize()
        elapsed = time.perf_counter() - start
        print(f"{dtype}: tail_hits={hits} elapsed={elapsed:.2f}s")
def main() -> None:
    """Parse the CLI options, require a CUDA accelerator, and run both experiments.

    Raises:
        RuntimeError: If no accelerator is available or the current
            accelerator is not a CUDA device.
    """
    cuda_required = "CUDA is required for this proof."

    # (flag, value type, default) in the order they are registered.
    cli_options = (
        ("--lower-tail-samples", int, 200_000_000),
        ("--lower-tail-chunk-size", int, 10_000_000),
        ("--race-trials", int, 100_000),
        ("--race-tail-tokens", int, 262_144),
        ("--race-gap", float, 20.5),
        ("--race-chunk-trials", int, 64),
        ("--seed", int, 2026),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default in cli_options:
        parser.add_argument(flag, type=value_type, default=default)
    args = parser.parse_args()

    if not torch.accelerator.is_available():
        raise RuntimeError(cuda_required)
    device = torch.accelerator.current_accelerator()
    if device.type != "cuda":
        raise RuntimeError(cuda_required)

    print(f"torch={torch.__version__} cuda={torch.version.cuda}")
    print(f"device={device}")

    # Both experiments start from the same seed.
    measure_exponential_lower_tail(
        device=device,
        samples=args.lower_tail_samples,
        chunk_size=args.lower_tail_chunk_size,
        seed=args.seed,
    )
    run_many_tail_race(
        device=device,
        trials=args.race_trials,
        num_tail_tokens=args.race_tail_tokens,
        gap=args.race_gap,
        chunk_trials=args.race_chunk_trials,
        seed=args.seed,
    )
# Run the experiments only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
+4 -3
View File
@@ -235,9 +235,10 @@ class ModelConfig:
temperature and top_k/top_p.
"""
use_fp64_gumbel: bool = False
"""Whether to use FP64 (instead of FP32) for the Gumbel noise used by the
sampler. FP64 reduces the chance of ties in Gumbel-max sampling at the cost
of significantly lower kernel throughput on most GPUs."""
"""Whether to use FP64 (instead of FP32) random noise for Gumbel-max and
equivalent exponential-race sampling. FP64 preserves lower-tail sampling
events that fp32 uniform/exponential draws can truncate, at the cost of
significantly lower throughput on most GPUs."""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
+36 -7
View File
@@ -75,9 +75,14 @@ class TopKTopPSampler(nn.Module):
Implementations may update the logits tensor in-place.
"""
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
def __init__(
self,
logprobs_mode: LogprobsMode = "raw_logprobs",
use_fp64_gumbel: bool = False,
) -> None:
super().__init__()
self.logprobs_mode = logprobs_mode
self.use_fp64_gumbel = use_fp64_gumbel
if current_platform.is_cuda():
# FlashInfer doesn't expose post-top-k/top-p logits/logprobs,
# so it can't be used when the configured mode requires them.
@@ -142,7 +147,10 @@ class TopKTopPSampler(nn.Module):
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), logits_to_return
return (
random_sample(probs, generators, self.use_fp64_gumbel),
logits_to_return,
)
def forward_cuda(
self,
@@ -163,6 +171,8 @@ class TopKTopPSampler(nn.Module):
"PyTorch-native implementation."
)
return self.forward_native(logits, generators, k, p)
if self.use_fp64_gumbel:
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
"FlashInfer does not support returning logits/logprobs"
)
@@ -190,16 +200,16 @@ class TopKTopPSampler(nn.Module):
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
if len(generators) != logits.shape[0]:
if len(generators) != logits.shape[0] and not self.use_fp64_gumbel:
return compiled_random_sample(logits), logits_to_return
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q = empty_exponential_noise_like(probs, self.use_fp64_gumbel)
q.exponential_()
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
return sample_with_exponential_noise(probs, q), logits_to_return
def forward_hip(
self,
@@ -216,6 +226,8 @@ class TopKTopPSampler(nn.Module):
"falling back to PyTorch-native."
)
return self.forward_native(logits, generators, k, p)
if self.use_fp64_gumbel:
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in (
"processed_logits",
"processed_logprobs",
@@ -404,16 +416,33 @@ def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
return logits.masked_fill_(logits < top_k_mask, -float("inf"))
def empty_exponential_noise_like(
    probs: torch.Tensor, use_fp64_gumbel: bool
) -> torch.Tensor:
    """Allocate an uninitialised noise buffer shaped like ``probs``.

    The buffer lives on the same device as ``probs``. Its dtype is float64
    when ``use_fp64_gumbel`` is set and ``probs.dtype`` otherwise. The
    caller is expected to fill it (e.g. with ``exponential_()``).
    """
    if use_fp64_gumbel:
        noise_dtype = torch.float64
    else:
        noise_dtype = probs.dtype
    return torch.empty(probs.shape, dtype=noise_dtype, device=probs.device)
def sample_with_exponential_noise(probs: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Return ``argmax(probs / q)`` over the last dimension as a flat tensor.

    The scores are computed in place to avoid an extra allocation:

    * matching dtypes: ``probs`` is overwritten with ``probs / q``;
    * differing dtypes: ``q`` is overwritten with ``(1 / q) * probs`` and
      ``probs`` is left untouched, so the scores keep ``q``'s dtype.
    """
    if q.dtype != probs.dtype:
        # Mixed precision: build the scores inside the noise buffer.
        race_scores = q.reciprocal_()
        race_scores.mul_(probs)
    else:
        race_scores = probs.div_(q)
    winners = race_scores.argmax(dim=-1)
    return winners.view(-1)
def random_sample(
probs: torch.Tensor,
generators: dict[int, torch.Generator],
use_fp64_gumbel: bool = False,
) -> torch.Tensor:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q = torch.empty_like(probs)
q = empty_exponential_noise_like(probs, use_fp64_gumbel)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
@@ -425,7 +454,7 @@ def random_sample(
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
return sample_with_exponential_noise(probs, q)
def flashinfer_sample(
+13 -2
View File
@@ -65,6 +65,7 @@ class RejectionSampler(nn.Module):
):
super().__init__()
self.sampler = sampler
self.use_fp64_gumbel = getattr(sampler, "use_fp64_gumbel", False)
logprobs_mode = self.sampler.logprobs_mode
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
@@ -176,6 +177,7 @@ class RejectionSampler(nn.Module):
sampling_metadata,
synthetic_mode=self.synthetic_mode,
synthetic_conditional_rates=self.synthetic_conditional_rates,
use_fp64_gumbel=self.use_fp64_gumbel,
)
logprobs_tensors = None
@@ -406,6 +408,7 @@ def rejection_sample(
sampling_metadata: SamplingMetadata,
synthetic_mode: bool = False,
synthetic_conditional_rates: torch.Tensor | None = None,
use_fp64_gumbel: bool = False,
) -> torch.Tensor:
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
@@ -480,6 +483,7 @@ def rejection_sample(
target_probs,
sampling_metadata,
device,
use_fp64_gumbel,
)
# Rejection sampling for random sampling requests.
@@ -669,13 +673,15 @@ def sample_recovered_tokens(
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
device: torch.device,
use_fp64_gumbel: bool = False,
) -> torch.Tensor:
# NOTE(woosuk): Create only one distribution for each request.
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
dtype=q_dtype,
device=device,
)
q.exponential_()
@@ -699,6 +705,7 @@ def sample_recovered_tokens(
vocab_size,
BLOCK_SIZE,
NO_DRAFT_PROBS=draft_probs is None,
USE_FP64_GUMBEL=use_fp64_gumbel,
)
return recovered_token_ids
@@ -861,6 +868,7 @@ def sample_recovered_tokens_kernel(
vocab_size,
BLOCK_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
USE_FP64_GUMBEL: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
@@ -877,7 +885,10 @@ def sample_recovered_tokens_kernel(
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
max_val = float("-inf")
if USE_FP64_GUMBEL:
max_val = tl.full((), float("-inf"), tl.float64)
else:
max_val = tl.full((), float("-inf"), tl.float32)
recovered_id = 0
for v in range(0, vocab_size, BLOCK_SIZE):
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
+7 -2
View File
@@ -58,11 +58,16 @@ class Sampler(nn.Module):
9. Return the final `SamplerOutput`.
"""
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
def __init__(
self,
logprobs_mode: LogprobsMode = "raw_logprobs",
use_fp64_gumbel: bool = False,
):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
self.topk_topp_sampler = TopKTopPSampler(logprobs_mode, use_fp64_gumbel)
self.pin_memory = is_pin_memory_available()
self.logprobs_mode = logprobs_mode
self.use_fp64_gumbel = use_fp64_gumbel
def forward(
self,
+11 -3
View File
@@ -32,6 +32,10 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import (
empty_exponential_noise_like,
sample_with_exponential_noise,
)
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
@@ -113,6 +117,7 @@ class SpecDecodeBaseProposer:
self.use_local_argmax_reduction: bool = (
self.speculative_config.use_local_argmax_reduction
)
self.use_fp64_gumbel = vllm_config.model_config.use_fp64_gumbel
self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
@@ -409,7 +414,9 @@ class SpecDecodeBaseProposer:
return logits.argmax(dim=-1), None
if sampling_metadata.all_greedy:
return logits.argmax(dim=-1), None
return compute_probs_and_sample_next_token(logits, sampling_metadata)
return compute_probs_and_sample_next_token(
logits, sampling_metadata, self.use_fp64_gumbel
)
def _sample_draft_tokens(
self,
@@ -1656,6 +1663,7 @@ class SpecDecodeBaseProposer:
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
use_fp64_gumbel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
@@ -1682,11 +1690,11 @@ def compute_probs_and_sample_next_token(
# of the generated tokens after rejection sampling.
# TODO(woosuk): Consider seeds.
q = torch.empty_like(probs)
q = empty_exponential_noise_like(probs, use_fp64_gumbel)
q.exponential_()
# NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
# will be used later for rejection sampling.
next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
next_token_ids = sample_with_exponential_noise(probs.clone(), q)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids)
+4 -1
View File
@@ -504,7 +504,10 @@ class GPUModelRunner(
self.use_async_scheduling = self.scheduler_config.async_scheduling
# Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.sampler = Sampler(
logprobs_mode=self.model_config.logprobs_mode,
use_fp64_gumbel=self.model_config.use_fp64_gumbel,
)
self.eplb_state: EplbState | None = None
self._moe_model: MixtureOfExperts | None = None