mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[BUG] Fix FP64 Gumbel precision coverage (#43150)
Signed-off-by: tianyu-z <zhangtianyupro@gmail.com> Signed-off-by: Tianyu Zhang <53099276+tianyu-z@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: OpenAI Codex <codex@openai.com>
This commit is contained in:
@@ -544,13 +544,15 @@ def native_sample_recovered_tokens(
|
||||
target_probs: torch.Tensor, # [num_tokens, vocab_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
use_fp64_gumbel: bool = False,
|
||||
) -> torch.Tensor:
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
|
||||
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
dtype=q_dtype,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
@@ -935,6 +937,67 @@ def test_sample_recovered_tokens(
|
||||
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
|
||||
|
||||
|
||||
def test_sample_recovered_tokens_uses_fp64_exponential_race_when_requested():
    """Compare ``sample_recovered_tokens`` with the native reference in FP64 mode.

    Both implementations receive identical inputs and ``use_fp64_gumbel=True``;
    the recovered token ids they return must match exactly.
    """
    batch_size, vocab_size, max_spec_len = 2, 64, 2
    num_tokens = batch_size * max_spec_len

    def _random_distribution() -> torch.Tensor:
        # Uniform noise pushed through softmax yields rows that sum to one.
        noise = torch.rand(
            num_tokens,
            vocab_size,
            dtype=torch.float32,
            device=DEVICE_TYPE,
        )
        return F.softmax(noise, dim=-1)

    # Keep the draw order (draft, target, multinomial) fixed so the default
    # RNG is consumed in the same sequence on every run.
    draft_probs = _random_distribution()
    target_probs = _random_distribution()
    draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)

    # One explicitly seeded generator per request index.
    generators = {}
    for req_idx in range(batch_size):
        generators[req_idx] = torch.Generator(device=DEVICE_TYPE).manual_seed(req_idx)

    unit_temperature = torch.ones(
        batch_size, dtype=torch.float32, device=DEVICE_TYPE
    )
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=unit_temperature,
        generators=generators,
    )

    per_request_drafts = draft_token_ids.reshape(batch_size, max_spec_len).tolist()
    spec_decode_metadata = create_spec_decode_metadata(
        per_request_drafts,
        target_probs.log(),
    )

    # NOTE(review): both calls below share the same generator objects through
    # ``sampling_metadata``; exact equality presumably relies on the reference
    # implementation not leaving them in an advanced state - confirm.
    shared_args = (
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
    )

    expected = native_sample_recovered_tokens(
        *shared_args,
        device=torch.device(DEVICE_TYPE),
        use_fp64_gumbel=True,
    )
    actual = sample_recovered_tokens(
        *shared_args,
        device=torch.device(DEVICE_TYPE),
        use_fp64_gumbel=True,
    )

    assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
########################### Tests for Synthetic Rejection Sampling #########
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from torch import Generator
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (
|
||||
apply_top_k_top_p_pytorch,
|
||||
random_sample,
|
||||
)
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
@@ -38,6 +43,10 @@ def _flashinfer_topk_topp_supported() -> bool:
|
||||
FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()
|
||||
|
||||
|
||||
def _seed_default_generator(seed: int) -> None:
    """Reseed the RNG state by delegating to ``set_random_seed``.

    Tests call this immediately before a reference draw and again before the
    code under test so that both consume the same random stream.
    NOTE(review): presumably this reseeds torch's default generator on the
    active device - confirm against ``set_random_seed``.
    """
    set_random_seed(seed)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_default_device():
|
||||
"""
|
||||
@@ -49,6 +58,35 @@ def reset_default_device():
|
||||
torch.set_default_device(original_device)
|
||||
|
||||
|
||||
def test_sampler_threads_fp64_gumbel_to_topk_topp_sampler():
    """A ``Sampler`` built with ``use_fp64_gumbel=True`` exposes the flag
    as truthy on its ``topk_topp_sampler`` attribute."""
    inner_sampler = Sampler(use_fp64_gumbel=True).topk_topp_sampler

    assert inner_sampler.use_fp64_gumbel
|
||||
|
||||
|
||||
def test_random_sample_uses_fp64_exponential_race_when_requested():
    """``random_sample(..., use_fp64_gumbel=True)`` must pick the same ids as
    an inline float64 exponential race drawn from the same seed."""
    torch.set_default_device(DEVICE_TYPE)

    rows = [
        [0.70, 0.20, 0.10],
        [0.05, 0.15, 0.80],
        [0.25, 0.25, 0.50],
    ]
    probs = torch.tensor(rows, dtype=torch.float32, device=DEVICE_TYPE)

    # Reference: float64 exponential noise, then argmax of probs * (1 / noise).
    _seed_default_generator(12345)
    noise = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    noise.exponential_()
    noise.reciprocal_()
    noise.mul_(probs)
    expected = noise.argmax(dim=-1).view(-1)

    # Same seed again so the function under test sees the same random stream.
    # A clone is passed so ``probs`` itself is never handed to the callee.
    _seed_default_generator(12345)
    actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True)

    assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
def test_topk_impl_equivalence():
|
||||
torch.set_default_device(DEVICE_TYPE)
|
||||
generator = Generator(device=DEVICE_TYPE).manual_seed(33)
|
||||
|
||||
@@ -1034,7 +1034,8 @@ def test_propose_stores_probabilistic_draft_probs(monkeypatch):
|
||||
proposer.model = model_mock
|
||||
proposer._draft_attn_layer_names = {"layer.0"}
|
||||
|
||||
def fake_compute_probs(logits, sampling_metadata):
|
||||
def fake_compute_probs(logits, sampling_metadata, use_fp64_gumbel):
|
||||
assert use_fp64_gumbel == proposer.use_fp64_gumbel
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
return probs.argmax(dim=-1), probs
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.llm_base_proposer import (
|
||||
compute_probs_and_sample_next_token,
|
||||
)
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
|
||||
def _seed_default_generator(seed: int) -> None:
    """Reseed the RNG state by delegating to ``set_random_seed``.

    Called right before the reference draw and again before the function
    under test so that both consume the same random stream.
    NOTE(review): presumably this reseeds torch's default generator on the
    active device - confirm against ``set_random_seed``.
    """
    set_random_seed(seed)
|
||||
|
||||
|
||||
def _make_sampling_metadata(batch_size: int) -> SamplingMetadata:
    """Build a ``SamplingMetadata`` for ``batch_size`` all-random requests.

    Temperature is one for every request; top-k/top-p, logprobs, prompt
    token ids and the allowed-token mask are ``None``; generators and bad
    words are empty dicts; penalty tensors are empty; and each request gets
    empty output/spec token lists.
    """

    def _no_penalty() -> torch.Tensor:
        # A fresh zero-length tensor per penalty field (no sharing).
        return torch.empty(0, device=DEVICE_TYPE)

    def _empty_token_lists() -> list[list[int]]:
        # One independent empty list per request.
        return [[] for _ in range(batch_size)]

    unit_temperature = torch.ones(
        batch_size, dtype=torch.float32, device=DEVICE_TYPE
    )

    return SamplingMetadata(
        temperature=unit_temperature,
        all_greedy=False,
        all_random=True,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=None,
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=_no_penalty(),
        presence_penalties=_no_penalty(),
        repetition_penalties=_no_penalty(),
        output_token_ids=_empty_token_lists(),
        spec_token_ids=_empty_token_lists(),
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessors(),
    )
|
||||
|
||||
|
||||
def test_compute_probs_and_sample_next_token_uses_fp64_exponential_race():
    """``compute_probs_and_sample_next_token`` with ``use_fp64_gumbel=True``
    must return the ids of an inline float64 exponential race and
    probabilities close to a float32 softmax of the logits."""
    batch_size, vocab_size = 4, 32

    # The logits come from a dedicated generator, not the default one.
    logits_rng = torch.Generator(device=DEVICE_TYPE).manual_seed(11)
    logits = torch.randn(
        batch_size,
        vocab_size,
        dtype=torch.float32,
        device=DEVICE_TYPE,
        generator=logits_rng,
    )
    metadata = _make_sampling_metadata(batch_size)

    # Reference: float64 exponential noise, then argmax of probs * (1 / noise).
    _seed_default_generator(12345)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    noise = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    noise.exponential_()
    noise.reciprocal_()
    noise.mul_(probs)
    expected_ids = noise.argmax(dim=-1).view(-1)

    # Same seed again so the function under test sees the same random stream.
    # A clone is passed so ``logits`` itself is never handed to the callee.
    _seed_default_generator(12345)
    sampled = compute_probs_and_sample_next_token(
        logits.clone(),
        metadata,
        use_fp64_gumbel=True,
    )
    actual_ids, actual_probs = sampled

    assert torch.equal(actual_ids, expected_ids)
    assert torch.allclose(actual_probs, probs)
|
||||
Reference in New Issue
Block a user