[BUG] Fix FP64 Gumbel precision coverage (#43150)

Signed-off-by: tianyu-z <zhangtianyupro@gmail.com>
Signed-off-by: Tianyu Zhang <53099276+tianyu-z@users.noreply.github.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
This commit is contained in:
Tianyu Zhang
2026-06-05 04:04:14 -07:00
committed by GitHub
parent 8a83e6f2d7
commit 7fe7800fa4
11 changed files with 391 additions and 21 deletions
+64 -1
View File
@@ -544,13 +544,15 @@ def native_sample_recovered_tokens(
target_probs: torch.Tensor, # [num_tokens, vocab_size]
sampling_metadata: SamplingMetadata,
device: torch.device,
use_fp64_gumbel: bool = False,
) -> torch.Tensor:
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
dtype=q_dtype,
device=device,
)
q.exponential_()
@@ -935,6 +937,67 @@ def test_sample_recovered_tokens(
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
def test_sample_recovered_tokens_uses_fp64_exponential_race_when_requested():
batch_size = 2
vocab_size = 64
max_spec_len = 2
num_tokens = batch_size * max_spec_len
draft_probs = torch.rand(
num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
)
draft_probs = F.softmax(draft_probs, dim=-1)
target_probs = torch.rand(
num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
)
target_probs = F.softmax(target_probs, dim=-1)
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
generators = {
i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size)
}
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
generators=generators,
)
spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.reshape(batch_size, max_spec_len).tolist(),
target_probs.log(),
)
expected = native_sample_recovered_tokens(
max_spec_len,
spec_decode_metadata.num_draft_tokens,
spec_decode_metadata.cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device=torch.device(DEVICE_TYPE),
use_fp64_gumbel=True,
)
actual = sample_recovered_tokens(
max_spec_len,
spec_decode_metadata.num_draft_tokens,
spec_decode_metadata.cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device=torch.device(DEVICE_TYPE),
use_fp64_gumbel=True,
)
assert torch.equal(actual, expected)
########################### Tests for Synthetic Rejection Sampling #########
+39 -1
View File
@@ -6,7 +6,12 @@ from torch import Generator
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.sample.ops.topk_topp_sampler import (
apply_top_k_top_p_pytorch,
random_sample,
)
from vllm.v1.sample.sampler import Sampler
DEVICE_TYPE = current_platform.device_type
@@ -38,6 +43,10 @@ def _flashinfer_topk_topp_supported() -> bool:
FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()
def _seed_default_generator(seed: int) -> None:
set_random_seed(seed)
@pytest.fixture(autouse=True)
def reset_default_device():
"""
@@ -49,6 +58,35 @@ def reset_default_device():
torch.set_default_device(original_device)
def test_sampler_threads_fp64_gumbel_to_topk_topp_sampler():
sampler = Sampler(use_fp64_gumbel=True)
assert sampler.topk_topp_sampler.use_fp64_gumbel
def test_random_sample_uses_fp64_exponential_race_when_requested():
torch.set_default_device(DEVICE_TYPE)
probs = torch.tensor(
[
[0.70, 0.20, 0.10],
[0.05, 0.15, 0.80],
[0.25, 0.25, 0.50],
],
dtype=torch.float32,
device=DEVICE_TYPE,
)
_seed_default_generator(12345)
q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
q.exponential_()
expected = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1)
_seed_default_generator(12345)
actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True)
assert torch.equal(actual, expected)
def test_topk_impl_equivalence():
torch.set_default_device(DEVICE_TYPE)
generator = Generator(device=DEVICE_TYPE).manual_seed(33)
+2 -1
View File
@@ -1034,7 +1034,8 @@ def test_propose_stores_probabilistic_draft_probs(monkeypatch):
proposer.model = model_mock
proposer._draft_attn_layer_names = {"layer.0"}
def fake_compute_probs(logits, sampling_metadata):
def fake_compute_probs(logits, sampling_metadata, use_fp64_gumbel):
assert use_fp64_gumbel == proposer.use_fp64_gumbel
probs = torch.softmax(logits, dim=-1)
return probs.argmax(dim=-1), probs
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.llm_base_proposer import (
compute_probs_and_sample_next_token,
)
DEVICE_TYPE = current_platform.device_type
def _seed_default_generator(seed: int) -> None:
set_random_seed(seed)
def _make_sampling_metadata(batch_size: int) -> SamplingMetadata:
return SamplingMetadata(
temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
all_greedy=False,
all_random=True,
top_p=None,
top_k=None,
generators={},
max_num_logprobs=None,
no_penalties=True,
prompt_token_ids=None,
frequency_penalties=torch.empty(0, device=DEVICE_TYPE),
presence_penalties=torch.empty(0, device=DEVICE_TYPE),
repetition_penalties=torch.empty(0, device=DEVICE_TYPE),
output_token_ids=[[] for _ in range(batch_size)],
spec_token_ids=[[] for _ in range(batch_size)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessors(),
)
def test_compute_probs_and_sample_next_token_uses_fp64_exponential_race():
batch_size = 4
vocab_size = 32
generator = torch.Generator(device=DEVICE_TYPE).manual_seed(11)
logits = torch.randn(
batch_size,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
generator=generator,
)
metadata = _make_sampling_metadata(batch_size)
_seed_default_generator(12345)
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
q.exponential_()
expected_ids = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1)
_seed_default_generator(12345)
actual_ids, actual_probs = compute_probs_and_sample_next_token(
logits.clone(),
metadata,
use_fp64_gumbel=True,
)
assert torch.equal(actual_ids, expected_ids)
assert torch.allclose(actual_probs, probs)