mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[BUG] Fix FP64 Gumbel precision coverage (#43150)
Signed-off-by: tianyu-z <zhangtianyupro@gmail.com> Signed-off-by: Tianyu Zhang <53099276+tianyu-z@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: OpenAI Codex <codex@openai.com>
This commit is contained in:
@@ -544,13 +544,15 @@ def native_sample_recovered_tokens(
|
||||
target_probs: torch.Tensor, # [num_tokens, vocab_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
use_fp64_gumbel: bool = False,
|
||||
) -> torch.Tensor:
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
|
||||
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
dtype=q_dtype,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
@@ -935,6 +937,67 @@ def test_sample_recovered_tokens(
|
||||
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
|
||||
|
||||
|
||||
def test_sample_recovered_tokens_uses_fp64_exponential_race_when_requested():
    """`sample_recovered_tokens` must agree with the native reference when
    both are asked to use fp64 exponential noise."""
    batch_size, vocab_size, max_spec_len = 2, 64, 2
    num_tokens = batch_size * max_spec_len
    device = torch.device(DEVICE_TYPE)

    def _random_probs() -> torch.Tensor:
        # Uniform noise pushed through softmax gives a valid distribution
        # per token row.
        raw = torch.rand(
            num_tokens,
            vocab_size,
            dtype=torch.float32,
            device=DEVICE_TYPE,
        )
        return F.softmax(raw, dim=-1)

    # Draw order (draft, then target, then the multinomial draw) is kept
    # fixed so the default RNG is consumed in the same sequence.
    draft_probs = _random_probs()
    target_probs = _random_probs()
    draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)

    # One explicitly seeded generator per request.
    generators = {}
    for req_idx in range(batch_size):
        generators[req_idx] = torch.Generator(device=DEVICE_TYPE).manual_seed(req_idx)

    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
        generators=generators,
    )
    spec_decode_metadata = create_spec_decode_metadata(
        draft_token_ids.reshape(batch_size, max_spec_len).tolist(),
        target_probs.log(),
    )

    # Both implementations receive exactly the same positional inputs.
    shared_args = (
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
    )
    expected = native_sample_recovered_tokens(
        *shared_args,
        device=device,
        use_fp64_gumbel=True,
    )
    actual = sample_recovered_tokens(
        *shared_args,
        device=device,
        use_fp64_gumbel=True,
    )

    assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
########################### Tests for Synthetic Rejection Sampling #########
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from torch import Generator
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (
|
||||
apply_top_k_top_p_pytorch,
|
||||
random_sample,
|
||||
)
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
@@ -38,6 +43,10 @@ def _flashinfer_topk_topp_supported() -> bool:
|
||||
FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()
|
||||
|
||||
|
||||
def _seed_default_generator(seed: int) -> None:
    """Reseed the RNG through `set_random_seed` so a test can replay the same
    random draws twice."""
    set_random_seed(seed)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_default_device():
|
||||
"""
|
||||
@@ -49,6 +58,35 @@ def reset_default_device():
|
||||
torch.set_default_device(original_device)
|
||||
|
||||
|
||||
def test_sampler_threads_fp64_gumbel_to_topk_topp_sampler():
    """The fp64 flag given to `Sampler` must reach its top-k/top-p sampler."""
    inner_sampler = Sampler(use_fp64_gumbel=True).topk_topp_sampler

    assert inner_sampler.use_fp64_gumbel
|
||||
|
||||
|
||||
def test_random_sample_uses_fp64_exponential_race_when_requested():
    """`random_sample(..., use_fp64_gumbel=True)` must reproduce a hand-rolled
    fp64 exponential race driven by the same seed."""
    torch.set_default_device(DEVICE_TYPE)
    rows = [
        [0.70, 0.20, 0.10],
        [0.05, 0.15, 0.80],
        [0.25, 0.25, 0.50],
    ]
    probs = torch.tensor(rows, dtype=torch.float32, device=DEVICE_TYPE)

    # Reference: fp64 noise, scores = probs * (1 / noise), argmax per row.
    _seed_default_generator(12345)
    noise = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    noise.exponential_()
    noise.reciprocal_()
    noise.mul_(probs)
    expected = noise.argmax(dim=-1).view(-1)

    # Replay the same seed through the function under test; it gets a clone
    # so `probs` itself is never handed to it.
    _seed_default_generator(12345)
    actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True)

    assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
def test_topk_impl_equivalence():
|
||||
torch.set_default_device(DEVICE_TYPE)
|
||||
generator = Generator(device=DEVICE_TYPE).manual_seed(33)
|
||||
|
||||
@@ -1034,7 +1034,8 @@ def test_propose_stores_probabilistic_draft_probs(monkeypatch):
|
||||
proposer.model = model_mock
|
||||
proposer._draft_attn_layer_names = {"layer.0"}
|
||||
|
||||
def fake_compute_probs(logits, sampling_metadata):
|
||||
def fake_compute_probs(logits, sampling_metadata, use_fp64_gumbel):
|
||||
assert use_fp64_gumbel == proposer.use_fp64_gumbel
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
return probs.argmax(dim=-1), probs
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.llm_base_proposer import (
|
||||
compute_probs_and_sample_next_token,
|
||||
)
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
|
||||
def _seed_default_generator(seed: int) -> None:
    """Reseed the RNG through `set_random_seed` so the reference computation
    and the call under test see the same random draws."""
    set_random_seed(seed)
|
||||
|
||||
|
||||
def _make_sampling_metadata(batch_size: int) -> SamplingMetadata:
    """Build an all-random `SamplingMetadata` for `batch_size` requests.

    Temperature is 1.0 for every request; top-k/top-p, penalties, per-request
    generators, bad words and the allowed-token mask are all left unset or
    empty.
    """

    def _no_penalty() -> torch.Tensor:
        # Zero-length placeholder; a fresh tensor for each penalty field.
        return torch.empty(0, device=DEVICE_TYPE)

    def _empty_token_lists() -> list[list[int]]:
        # One independent empty list per request.
        return [[] for _ in range(batch_size)]

    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
    return SamplingMetadata(
        temperature=temperature,
        all_greedy=False,
        all_random=True,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=None,
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=_no_penalty(),
        presence_penalties=_no_penalty(),
        repetition_penalties=_no_penalty(),
        output_token_ids=_empty_token_lists(),
        spec_token_ids=_empty_token_lists(),
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessors(),
    )
|
||||
|
||||
|
||||
def test_compute_probs_and_sample_next_token_uses_fp64_exponential_race():
    """With `use_fp64_gumbel=True` the draft sampler must pick the same tokens
    as a hand-rolled fp64 exponential race and still return the softmax
    probabilities."""
    batch_size, vocab_size = 4, 32
    logits_rng = torch.Generator(device=DEVICE_TYPE).manual_seed(11)
    logits = torch.randn(
        batch_size,
        vocab_size,
        dtype=torch.float32,
        device=DEVICE_TYPE,
        generator=logits_rng,
    )
    metadata = _make_sampling_metadata(batch_size)

    # Reference: fp32 softmax, fp64 noise, scores = probs * (1 / noise).
    _seed_default_generator(12345)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    noise = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    noise.exponential_()
    noise.reciprocal_()
    noise.mul_(probs)
    expected_ids = noise.argmax(dim=-1).view(-1)

    # Replay the same seed through the function under test.
    _seed_default_generator(12345)
    actual_ids, actual_probs = compute_probs_and_sample_next_token(
        logits.clone(),
        metadata,
        use_fp64_gumbel=True,
    )

    assert torch.equal(actual_ids, expected_ids)
    assert torch.allclose(actual_probs, probs)
|
||||
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CUDA proof for fp32 exponential-race tail truncation.
|
||||
|
||||
This script is intentionally not a unit test. It is a reproducible, GPU-only
|
||||
statistical proof for the hidden Gumbel-max idiom:
|
||||
|
||||
q.exponential_()
|
||||
sample = (probs / q).argmax()
|
||||
|
||||
For q ~ Exp(1), this is equivalent to argmax(log(probs) + Gumbel). On CUDA,
|
||||
fp32 exponential samples inherit a 24-bit uniform lower-tail cutoff, so very
|
||||
small q values are impossible. The many-tail experiment below chooses a case
|
||||
where a correct sampler should select a low-probability tail token dozens of
|
||||
times, while fp32 q cannot select one.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _seed(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def measure_exponential_lower_tail(
    *,
    device: torch.device,
    samples: int,
    chunk_size: int,
    seed: int,
) -> None:
    """Report how often exponential draws land below 2**-24 in fp32 and fp64.

    For each dtype the RNG is reseeded with ``seed``, ``samples`` values are
    drawn with ``Tensor.exponential_()`` in chunks of at most ``chunk_size``,
    and the count of draws below the threshold, the observed min/max and the
    elapsed wall-clock time are printed.

    Args:
        device: Device on which the sample tensors are allocated.
        samples: Total number of draws per dtype.
        chunk_size: Maximum number of draws allocated at once; must be > 0.
        seed: Seed applied before each dtype pass.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # A zero chunk would never shrink ``remaining`` and the loop below
        # would spin forever; a negative chunk is an invalid tensor size.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    threshold = 2.0**-24
    print(f"lower-tail threshold: {threshold:.18e}")
    for dtype in (torch.float32, torch.float64):
        _seed(seed)
        count_below = 0
        min_q = float("inf")
        max_q = 0.0
        start = time.perf_counter()
        remaining = samples
        while remaining > 0:
            n = min(chunk_size, remaining)
            q = torch.empty((n,), dtype=dtype, device=device)
            q.exponential_()
            count_below += int((q < threshold).sum().item())
            min_q = min(min_q, float(q.min().item()))
            max_q = max(max_q, float(q.max().item()))
            remaining -= n
        # Wait for outstanding device work before stopping the clock.
        torch.accelerator.synchronize()
        elapsed = time.perf_counter() - start
        print(
            f"{dtype}: samples={samples} count(q < 2^-24)={count_below} "
            f"min={min_q:.18e} max={max_q:.6f} elapsed={elapsed:.2f}s"
        )
|
||||
|
||||
|
||||
def run_many_tail_race(
    *,
    device: torch.device,
    trials: int,
    num_tail_tokens: int,
    gap: float,
    chunk_trials: int,
    seed: int,
) -> None:
    """Race one head token against many low-probability tail tokens.

    The head has weight 1 and each of the ``num_tail_tokens`` tail tokens has
    weight ``exp(-gap)``. Every trial draws exponential noise for the head
    and for each tail token and counts a hit when the best tail score
    ``p_tail / q`` beats the head score ``1 / q0``. The experiment runs once
    in fp32 and once in fp64 from the same seed and prints the hit counts
    next to the analytically expected number of hits.

    Args:
        device: Device on which the noise tensors are allocated.
        trials: Total number of races per dtype.
        num_tail_tokens: Number of tail tokens competing in each race.
        gap: Log-weight gap between the head token and each tail token.
        chunk_trials: Maximum number of races evaluated at once; must be > 0.
        seed: Seed applied before each dtype pass.

    Raises:
        ValueError: If ``chunk_trials`` is not positive.
    """
    if chunk_trials <= 0:
        # A zero chunk would never shrink ``remaining`` and the loop below
        # would spin forever; a negative chunk is an invalid tensor size.
        raise ValueError(f"chunk_trials must be positive, got {chunk_trials}")

    p_tail = math.exp(-gap)
    # P(some tail token wins) = N * p_tail / (1 + N * p_tail).
    expected_tail_hits = (
        trials * (num_tail_tokens * p_tail) / (1.0 + num_tail_tokens * p_tail)
    )
    print(
        "many-tail race: "
        f"trials={trials} num_tail_tokens={num_tail_tokens} gap={gap} "
        f"expected_tail_hits={expected_tail_hits:.4f}"
    )

    for dtype in (torch.float32, torch.float64):
        _seed(seed)
        hits = 0
        p0 = torch.tensor(1.0, dtype=dtype, device=device)
        pt = torch.tensor(p_tail, dtype=dtype, device=device)
        start = time.perf_counter()
        remaining = trials
        while remaining > 0:
            batch = min(chunk_trials, remaining)
            q0 = torch.empty((batch,), dtype=dtype, device=device)
            q0.exponential_()
            qt = torch.empty((batch, num_tail_tokens), dtype=dtype, device=device)
            qt.exponential_()
            head_score = p0 / q0
            # Only the best tail token in each race matters.
            tail_score = (pt / qt).amax(dim=-1)
            hits += int((tail_score > head_score).sum().item())
            remaining -= batch
        # Wait for outstanding device work before stopping the clock.
        torch.accelerator.synchronize()
        elapsed = time.perf_counter() - start
        print(f"{dtype}: tail_hits={hits} elapsed={elapsed:.2f}s")
|
||||
|
||||
|
||||
def main() -> None:
    """Parse the command line, require a CUDA accelerator, and run both
    experiments with the requested sizes and seed."""
    # (flag, type, default) for every option, registered in this order.
    cli_options = (
        ("--lower-tail-samples", int, 200_000_000),
        ("--lower-tail-chunk-size", int, 10_000_000),
        ("--race-trials", int, 100_000),
        ("--race-tail-tokens", int, 262_144),
        ("--race-gap", float, 20.5),
        ("--race-chunk-trials", int, 64),
        ("--seed", int, 2026),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default in cli_options:
        parser.add_argument(flag, type=value_type, default=default)
    args = parser.parse_args()

    if not torch.accelerator.is_available():
        raise RuntimeError("CUDA is required for this proof.")

    device = torch.accelerator.current_accelerator()
    if device.type != "cuda":
        # An accelerator other than CUDA is rejected as well.
        raise RuntimeError("CUDA is required for this proof.")

    print(f"torch={torch.__version__} cuda={torch.version.cuda}")
    print(f"device={device}")

    lower_tail_kwargs = dict(
        device=device,
        samples=args.lower_tail_samples,
        chunk_size=args.lower_tail_chunk_size,
        seed=args.seed,
    )
    measure_exponential_lower_tail(**lower_tail_kwargs)

    race_kwargs = dict(
        device=device,
        trials=args.race_trials,
        num_tail_tokens=args.race_tail_tokens,
        gap=args.race_gap,
        chunk_trials=args.race_chunk_trials,
        seed=args.seed,
    )
    run_many_tail_race(**race_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -235,9 +235,10 @@ class ModelConfig:
|
||||
temperature and top_k/top_p.
|
||||
"""
|
||||
use_fp64_gumbel: bool = False
|
||||
"""Whether to use FP64 (instead of FP32) for the Gumbel noise used by the
|
||||
sampler. FP64 reduces the chance of ties in Gumbel-max sampling at the cost
|
||||
of significantly lower kernel throughput on most GPUs."""
|
||||
"""Whether to use FP64 (instead of FP32) random noise for Gumbel-max and
|
||||
equivalent exponential-race sampling. FP64 preserves lower-tail sampling
|
||||
events that fp32 uniform/exponential draws can truncate, at the cost of
|
||||
significantly lower throughput on most GPUs."""
|
||||
disable_sliding_window: bool = False
|
||||
"""Whether to disable sliding window. If True, we will disable the sliding
|
||||
window functionality of the model, capping to sliding window size. If the
|
||||
|
||||
@@ -75,9 +75,14 @@ class TopKTopPSampler(nn.Module):
|
||||
Implementations may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
use_fp64_gumbel: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.logprobs_mode = logprobs_mode
|
||||
self.use_fp64_gumbel = use_fp64_gumbel
|
||||
if current_platform.is_cuda():
|
||||
# FlashInfer doesn't expose post-top-k/top-p logits/logprobs,
|
||||
# so it can't be used when the configured mode requires them.
|
||||
@@ -142,7 +147,10 @@ class TopKTopPSampler(nn.Module):
|
||||
elif self.logprobs_mode == "processed_logprobs":
|
||||
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators), logits_to_return
|
||||
return (
|
||||
random_sample(probs, generators, self.use_fp64_gumbel),
|
||||
logits_to_return,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@@ -163,6 +171,8 @@ class TopKTopPSampler(nn.Module):
|
||||
"PyTorch-native implementation."
|
||||
)
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
if self.use_fp64_gumbel:
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
|
||||
"FlashInfer does not support returning logits/logprobs"
|
||||
)
|
||||
@@ -190,16 +200,16 @@ class TopKTopPSampler(nn.Module):
|
||||
elif self.logprobs_mode == "processed_logprobs":
|
||||
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
if len(generators) != logits.shape[0]:
|
||||
if len(generators) != logits.shape[0] and not self.use_fp64_gumbel:
|
||||
return compiled_random_sample(logits), logits_to_return
|
||||
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q = empty_exponential_noise_like(probs, self.use_fp64_gumbel)
|
||||
q.exponential_()
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
|
||||
return sample_with_exponential_noise(probs, q), logits_to_return
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
@@ -216,6 +226,8 @@ class TopKTopPSampler(nn.Module):
|
||||
"falling back to PyTorch-native."
|
||||
)
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
if self.use_fp64_gumbel:
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
assert self.logprobs_mode not in (
|
||||
"processed_logits",
|
||||
"processed_logprobs",
|
||||
@@ -404,16 +416,33 @@ def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
|
||||
return logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||
|
||||
|
||||
def empty_exponential_noise_like(
    probs: torch.Tensor, use_fp64_gumbel: bool
) -> torch.Tensor:
    """Allocate an uninitialized noise buffer with the shape and device of
    ``probs``.

    The buffer is float64 when ``use_fp64_gumbel`` is set; otherwise it takes
    the dtype of ``probs``. The caller is expected to fill it, e.g. with
    ``exponential_()``.
    """
    if use_fp64_gumbel:
        noise_dtype = torch.float64
    else:
        noise_dtype = probs.dtype
    return torch.empty(probs.shape, dtype=noise_dtype, device=probs.device)
|
||||
|
||||
|
||||
def sample_with_exponential_noise(probs: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """Return, for each row, the index that maximizes ``probs / q``.

    The computation is done in place to avoid an extra allocation:

    * same dtype: ``probs`` is overwritten with ``probs / q``;
    * different dtype: ``q`` is overwritten with ``(1 / q) * probs`` and
      ``probs`` is left untouched, so the scores carry the dtype of ``q``.

    The result is flattened to one dimension.
    """
    if q.dtype != probs.dtype:
        # Mixed dtypes: build the scores inside the noise buffer.
        q.reciprocal_()
        q.mul_(probs)
        ratios = q
    else:
        ratios = probs.div_(q)
    return ratios.argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
def random_sample(
|
||||
probs: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
use_fp64_gumbel: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Randomly sample from the probabilities.
|
||||
|
||||
We use this function instead of torch.multinomial because torch.multinomial
|
||||
causes CPU-GPU synchronization.
|
||||
"""
|
||||
q = torch.empty_like(probs)
|
||||
q = empty_exponential_noise_like(probs, use_fp64_gumbel)
|
||||
# NOTE(woosuk): To batch-process the requests without their own seeds,
|
||||
# which is the common case, we first assume that every request does
|
||||
# not have its own seed. Then, we overwrite the values for the requests
|
||||
@@ -425,7 +454,7 @@ def random_sample(
|
||||
# one by one. Optimize this.
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||
return sample_with_exponential_noise(probs, q)
|
||||
|
||||
|
||||
def flashinfer_sample(
|
||||
|
||||
@@ -65,6 +65,7 @@ class RejectionSampler(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.sampler = sampler
|
||||
self.use_fp64_gumbel = getattr(sampler, "use_fp64_gumbel", False)
|
||||
logprobs_mode = self.sampler.logprobs_mode
|
||||
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
|
||||
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
|
||||
@@ -176,6 +177,7 @@ class RejectionSampler(nn.Module):
|
||||
sampling_metadata,
|
||||
synthetic_mode=self.synthetic_mode,
|
||||
synthetic_conditional_rates=self.synthetic_conditional_rates,
|
||||
use_fp64_gumbel=self.use_fp64_gumbel,
|
||||
)
|
||||
|
||||
logprobs_tensors = None
|
||||
@@ -406,6 +408,7 @@ def rejection_sample(
|
||||
sampling_metadata: SamplingMetadata,
|
||||
synthetic_mode: bool = False,
|
||||
synthetic_conditional_rates: torch.Tensor | None = None,
|
||||
use_fp64_gumbel: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert draft_token_ids.ndim == 1
|
||||
assert draft_probs is None or draft_probs.ndim == 2
|
||||
@@ -480,6 +483,7 @@ def rejection_sample(
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device,
|
||||
use_fp64_gumbel,
|
||||
)
|
||||
|
||||
# Rejection sampling for random sampling requests.
|
||||
@@ -669,13 +673,15 @@ def sample_recovered_tokens(
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
use_fp64_gumbel: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): Create only one distribution for each request.
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
dtype=q_dtype,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
@@ -699,6 +705,7 @@ def sample_recovered_tokens(
|
||||
vocab_size,
|
||||
BLOCK_SIZE,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
USE_FP64_GUMBEL=use_fp64_gumbel,
|
||||
)
|
||||
return recovered_token_ids
|
||||
|
||||
@@ -861,6 +868,7 @@ def sample_recovered_tokens_kernel(
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
USE_FP64_GUMBEL: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
|
||||
@@ -877,7 +885,10 @@ def sample_recovered_tokens_kernel(
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
|
||||
|
||||
max_val = float("-inf")
|
||||
if USE_FP64_GUMBEL:
|
||||
max_val = tl.full((), float("-inf"), tl.float64)
|
||||
else:
|
||||
max_val = tl.full((), float("-inf"), tl.float32)
|
||||
recovered_id = 0
|
||||
for v in range(0, vocab_size, BLOCK_SIZE):
|
||||
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
@@ -58,11 +58,16 @@ class Sampler(nn.Module):
|
||||
9. Return the final `SamplerOutput`.
|
||||
"""
|
||||
|
||||
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
|
||||
def __init__(
|
||||
self,
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
use_fp64_gumbel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
|
||||
self.topk_topp_sampler = TopKTopPSampler(logprobs_mode, use_fp64_gumbel)
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.logprobs_mode = logprobs_mode
|
||||
self.use_fp64_gumbel = use_fp64_gumbel
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -32,6 +32,10 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (
|
||||
empty_exponential_noise_like,
|
||||
sample_with_exponential_noise,
|
||||
)
|
||||
from vllm.v1.sample.sampler import _SAMPLING_EPS
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.utils import (
|
||||
@@ -113,6 +117,7 @@ class SpecDecodeBaseProposer:
|
||||
self.use_local_argmax_reduction: bool = (
|
||||
self.speculative_config.use_local_argmax_reduction
|
||||
)
|
||||
self.use_fp64_gumbel = vllm_config.model_config.use_fp64_gumbel
|
||||
|
||||
self.max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
@@ -409,7 +414,9 @@ class SpecDecodeBaseProposer:
|
||||
return logits.argmax(dim=-1), None
|
||||
if sampling_metadata.all_greedy:
|
||||
return logits.argmax(dim=-1), None
|
||||
return compute_probs_and_sample_next_token(logits, sampling_metadata)
|
||||
return compute_probs_and_sample_next_token(
|
||||
logits, sampling_metadata, self.use_fp64_gumbel
|
||||
)
|
||||
|
||||
def _sample_draft_tokens(
|
||||
self,
|
||||
@@ -1656,6 +1663,7 @@ class SpecDecodeBaseProposer:
|
||||
def compute_probs_and_sample_next_token(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
use_fp64_gumbel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if sampling_metadata.all_greedy:
|
||||
# For greedy requests, draft_probs is not used in rejection sampling.
|
||||
@@ -1682,11 +1690,11 @@ def compute_probs_and_sample_next_token(
|
||||
# of the generated tokens after rejection sampling.
|
||||
|
||||
# TODO(woosuk): Consider seeds.
|
||||
q = torch.empty_like(probs)
|
||||
q = empty_exponential_noise_like(probs, use_fp64_gumbel)
|
||||
q.exponential_()
|
||||
# NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
|
||||
# will be used later for rejection sampling.
|
||||
next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
|
||||
next_token_ids = sample_with_exponential_noise(probs.clone(), q)
|
||||
if not sampling_metadata.all_random:
|
||||
greedy_token_ids = probs.argmax(dim=-1)
|
||||
next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids)
|
||||
|
||||
@@ -504,7 +504,10 @@ class GPUModelRunner(
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
|
||||
# Sampler
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
self.sampler = Sampler(
|
||||
logprobs_mode=self.model_config.logprobs_mode,
|
||||
use_fp64_gumbel=self.model_config.use_fp64_gumbel,
|
||||
)
|
||||
|
||||
self.eplb_state: EplbState | None = None
|
||||
self._moe_model: MixtureOfExperts | None = None
|
||||
|
||||
Reference in New Issue
Block a user