[Model Runner V2] Apply synthetic mode to probabilistic rejection sampler (#41035)

This commit is contained in:
Giancarlo Delfin
2026-05-12 15:37:03 -05:00
committed by GitHub
parent 0ce6613b9c
commit fe5b4e0fe7
5 changed files with 138 additions and 165 deletions
+2 -3
View File
@@ -106,13 +106,12 @@ steps:
- vllm/v1/worker/gpu/
- vllm/v1/worker/gpu_worker.py
- tests/v1/spec_decode/test_max_len.py
- tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
- tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
- tests/v1/spec_decode/test_rejection_sampler_utils.py
- tests/v1/e2e/spec_decode/test_spec_decode.py
commands:
- set -x
- export VLLM_USE_V2_MODEL_RUNNER=1
- pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
- pytest -v -s v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
- pytest -v -s v1/spec_decode/test_rejection_sampler_utils.py
- pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
- pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"
@@ -6,8 +6,8 @@ import math
import pytest
import torch
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
probabilistic_rejection_sample,
from vllm.v1.worker.gpu.spec_decode.rejection_sampler_utils import (
rejection_sample,
)
VOCAB_SIZE = 4096
@@ -167,7 +167,7 @@ def test_stochastic_rejection_sample(num_speculative_steps: int, temperature: fl
num_trials=num_trials,
)
sampled, num_sampled = probabilistic_rejection_sample(
sampled, num_sampled = rejection_sample(
**inputs, num_speculative_steps=num_speculative_steps
)
@@ -201,7 +201,7 @@ def test_greedy_rejection_sample(num_speculative_steps: int):
num_trials=num_trials,
)
sampled, num_sampled = probabilistic_rejection_sample(
sampled, num_sampled = rejection_sample(
**inputs, num_speculative_steps=num_speculative_steps
)
@@ -213,3 +213,70 @@ def test_greedy_rejection_sample(num_speculative_steps: int):
assert (sampled[accepted_mask] == target_argmax).all(), (
"Greedy sampling produced tokens that are not the target argmax"
)
@pytest.mark.parametrize(
    "num_speculative_steps,temperature,unconditional_rates",
    [
        (3, 1.0, [0.9, 0.5, 0.2]),
        (3, 0.0, [0.9, 0.5, 0.2]),
        (3, 1.0, [1.0, 1.0, 1.0]),
        (3, 0.0, [1.0, 1.0, 1.0]),
        (3, 1.0, [0.0, 0.0, 0.0]),
        (3, 0.0, [0.0, 0.0, 0.0]),
        (1, 1.0, [0.7]),
        (1, 0.0, [0.7]),
    ],
)
def test_synthetic_rejection_sample(
    num_speculative_steps: int,
    temperature: float,
    unconditional_rates: list[float],
):
    """
    Check the empirical per-position acceptance rates in synthetic mode.

    ``unconditional_rates[i]`` is the probability that draft steps 0..i are
    all accepted, i.e. the product of the conditional rates up to step i.
    Over many trials this should match the fraction of requests whose
    accepted-token count is at least ``i + 1``.
    """
    from vllm.v1.spec_decode.utils import unconditional_to_conditional_rates

    torch.manual_seed(42)
    device = "cuda"
    deviation_tol = 1e-2
    trial_count = 10 * VOCAB_SIZE

    # Draw target first, then draft, so the RNG stream is consumed in a
    # fixed order for the seed above.
    randn_kwargs = {"device": device, "dtype": torch.float32}
    target_row = torch.randn(VOCAB_SIZE, **randn_kwargs)
    draft_row = torch.randn(VOCAB_SIZE, **randn_kwargs)
    if temperature > 0:
        # Scale in place; the greedy case (temperature == 0) is left as is.
        target_row.div_(temperature)
        draft_row.div_(temperature)

    sampler_inputs = _build_rejection_sample_inputs(
        target_row,
        draft_row,
        num_speculative_steps,
        temperature=temperature,
        num_trials=trial_count,
    )

    rates_tensor = torch.tensor(
        unconditional_to_conditional_rates(unconditional_rates),
        dtype=torch.float32,
        device=device,
    )

    _, num_sampled = rejection_sample(
        **sampler_inputs,
        num_speculative_steps=num_speculative_steps,
        synthetic_conditional_rates=rates_tensor,
    )

    # Each request's count includes one extra resampled/bonus token.
    accepted_counts = num_sampled - 1
    for i, expected_rate in enumerate(unconditional_rates):
        reached_step = accepted_counts >= i + 1
        observed_rate = reached_step.float().mean().item()
        deviation = abs(observed_rate - expected_rate)
        assert deviation < deviation_tol, (
            f"Step {i}: observed rate {observed_rate:.4f} deviates from "
            f"expected rate {expected_rate:.4f} by more than {deviation_tol}."
        )
@@ -12,11 +12,8 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
probabilistic_rejection_sample,
)
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
synthetic_rejection_sample,
from vllm.v1.worker.gpu.spec_decode.rejection_sampler_utils import (
rejection_sample,
)
@@ -101,59 +98,42 @@ class RejectionSampler:
input_batch: InputBatch,
draft_logits: torch.Tensor | None = None,
) -> SamplerOutput:
draft_sampled = input_batch.input_ids[input_batch.logits_indices]
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
if self.rejection_sample_method == "standard":
pos = input_batch.positions[input_batch.logits_indices]
processed_logits = self.sampler.apply_sampling_params(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
pos,
draft_sampled,
input_batch.expanded_local_pos,
)
sampled, num_sampled = probabilistic_rejection_sample(
processed_logits,
draft_logits,
draft_sampled,
input_batch.cu_num_logits,
pos,
input_batch.idx_mapping,
input_batch.expanded_idx_mapping,
input_batch.expanded_local_pos,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.num_speculative_steps,
)
logprobs_tensors = self._get_logprobs_tensors(
input_batch,
sampled,
num_sampled,
processed_logits
if self.sampler.logprobs_mode == "processed_logprobs"
else logits,
)
elif self.rejection_sample_method == "synthetic":
sampler_output = self.sampler(logits, input_batch)
logprobs_tensors = sampler_output.logprobs_tensors
sampled, num_sampled = synthetic_rejection_sample(
sampler_output.sampled_token_ids.view(-1),
draft_sampled,
input_batch.cu_num_logits,
input_batch.positions[input_batch.logits_indices],
input_batch.idx_mapping,
self.sampler.sampling_states.seeds.gpu,
self.synthetic_conditional_rates,
self.num_speculative_steps,
)
else:
raise ValueError(
f"Unknown rejection sample method: {self.rejection_sample_method}"
)
draft_sampled = input_batch.input_ids[input_batch.logits_indices]
pos = input_batch.positions[input_batch.logits_indices]
processed_logits = self.sampler.apply_sampling_params(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
pos,
draft_sampled,
input_batch.expanded_local_pos,
)
sampled, num_sampled = rejection_sample(
processed_logits,
draft_logits,
draft_sampled,
input_batch.cu_num_logits,
pos,
input_batch.idx_mapping,
input_batch.expanded_idx_mapping,
input_batch.expanded_local_pos,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.num_speculative_steps,
self.synthetic_conditional_rates,
)
logprobs_tensors = self._get_logprobs_tensors(
input_batch,
sampled,
num_sampled,
processed_logits
if self.sampler.logprobs_mode == "processed_logprobs"
else logits,
)
return SamplerOutput(
sampled_token_ids=sampled,
@@ -154,7 +154,7 @@ def _compute_block_stats_kernel(
@triton.jit
def _probabilistic_rejection_kernel(
def _rejection_kernel(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr,
sampled_stride,
@@ -198,9 +198,12 @@ def _probabilistic_rejection_kernel(
seed_ptr,
# [num_logits]
pos_ptr,
# [num_speculative_steps]
synthetic_conditional_rates_ptr,
vocab_num_blocks,
PADDED_VOCAB_NUM_BLOCKS: tl.constexpr,
HAS_DRAFT_LOGITS: tl.constexpr,
SYNTHETIC_MODE: tl.constexpr,
):
req_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
@@ -217,7 +220,7 @@ def _probabilistic_rejection_kernel(
for i in range(num_tokens - 1):
if accepted:
logit_idx = start_idx + i
draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1).to(tl.int64)
if temp == 0.0:
# Greedy sampling. Accept IFF draft matches target argmax.
# NOTE: Target argmax is stored directly so that resampling
@@ -236,9 +239,19 @@ def _probabilistic_rejection_kernel(
target_local_argmax_ptr
+ logit_idx * target_local_argmax_stride
+ max_target_block_idx
).to(tl.int64)
if SYNTHETIC_MODE:
pos = tl.load(pos_ptr + logit_idx)
u = tl_rand64(seed, pos, includes_zero=False)
rate = tl.load(synthetic_conditional_rates_ptr + i)
accepted &= u < rate
else:
accepted &= target_argmax == draft_sampled
tl.store(
sampled_ptr + req_idx * sampled_stride + i,
draft_sampled if accepted else target_argmax,
)
accepted &= target_argmax == draft_sampled
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_argmax)
else:
target_logit = tl.load(
target_logits_ptr + logit_idx * target_logits_stride + draft_sampled
@@ -275,9 +288,14 @@ def _probabilistic_rejection_kernel(
else:
# One-hot draft: q(draft_token) = 1, log_q = 0.
draft_log_prob = 0
# Probability ratio test: p(x) > u * q(x)
# Equivalent log form: log_p(x) > log(u) + log_q(x)
accepted &= target_log_prob > tl.log(u) + draft_log_prob
if SYNTHETIC_MODE:
rate = tl.load(synthetic_conditional_rates_ptr + i)
accepted &= u < rate
else:
# Probability ratio test: p(x) > u * q(x)
# Equivalent log form: log_p(x) > log(u) + log_q(x)
accepted &= target_log_prob > tl.log(u) + draft_log_prob
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
rejected_step += accepted
tl.store(rejected_steps_ptr + req_idx, rejected_step)
@@ -470,7 +488,7 @@ def _insert_resampled_kernel(
)
def probabilistic_rejection_sample(
def rejection_sample(
# [num_logits, V]
target_logits: torch.Tensor,
# [max_num_reqs, num_speculative_steps, V]
@@ -492,6 +510,8 @@ def probabilistic_rejection_sample(
# [max_num_reqs]
seed: torch.Tensor,
num_speculative_steps: int,
# [num_speculative_steps]
synthetic_conditional_rates: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
num_logits, vocab_size = target_logits.shape
@@ -557,7 +577,7 @@ def probabilistic_rejection_sample(
num_sampled = sampled.new_empty(num_reqs, dtype=torch.int32)
target_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
draft_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
_probabilistic_rejection_kernel[(num_reqs,)](
_rejection_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
@@ -584,9 +604,11 @@ def probabilistic_rejection_sample(
temperature,
seed,
pos,
synthetic_conditional_rates,
vocab_num_blocks,
PADDED_VOCAB_NUM_BLOCKS=padded_vocab_num_blocks,
HAS_DRAFT_LOGITS=has_draft_logits,
SYNTHETIC_MODE=synthetic_conditional_rates is not None,
num_warps=1,
)
@@ -1,95 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.gumbel import tl_rand64
@triton.jit
def _synthetic_rejection_sample_kernel(
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]
    num_sampled_ptr,
    # [num_draft_tokens + num_reqs]
    target_sampled_ptr,
    # [num_draft_tokens + num_reqs]
    input_ids_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    pos_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    seeds_ptr,
    # [num_speculative_steps]
    acceptance_rates_ptr,
):
    # Synthetic rejection sampling: each draft step is accepted by comparing
    # a random draw against a fixed per-step acceptance rate, rather than by
    # inspecting logits. Each program instance handles the request selected
    # by program_id(0) and writes that request's output row and token count.
    req_idx = tl.program_id(0)
    # This request's logits occupy [start_idx, end_idx) in the flat arrays.
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx
    # The seed is looked up via the request's mapped state index, not its
    # batch index.
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    seed = tl.load(seeds_ptr + req_state_idx)
    num_sampled = 0
    rejected = False
    # The first num_tokens - 1 positions are the draft steps; once one is
    # rejected, the remaining iterations do nothing.
    for i in range(num_tokens - 1):
        if not rejected:
            logit_idx = start_idx + i
            pos = tl.load(pos_ptr + logit_idx)
            # NOTE(review): tl_rand64 is defined elsewhere; presumably a
            # uniform draw keyed on (seed, pos) that excludes zero - confirm.
            u = tl_rand64(seed, pos, includes_zero=False)
            acceptance_rate = tl.load(acceptance_rates_ptr + i)
            if u < acceptance_rate:
                # Accepted: emit the draft token, which is read one slot
                # after the current logit index.
                sampled = tl.load(input_ids_ptr + logit_idx + 1).to(tl.int64)
            else:
                # Rejected: emit the target-sampled token at this position
                # and stop processing further draft steps.
                sampled = tl.load(target_sampled_ptr + logit_idx)
                rejected = True
            tl.store(sampled_ptr + req_idx * sampled_stride + i, sampled)
            num_sampled += 1
    if not rejected:
        # Every draft step was accepted: append the target-sampled token at
        # the request's last logit position as one extra token.
        target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
        tl.store(
            sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
        )
        num_sampled += 1
    # Only the first num_sampled entries of this request's output row are
    # written by this kernel.
    tl.store(num_sampled_ptr + req_idx, num_sampled)
def synthetic_rejection_sample(
    # [num_draft_tokens + num_reqs]
    target_sampled: torch.Tensor,
    # [num_draft_tokens + num_reqs]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    seed: torch.Tensor,
    # [num_speculative_steps]
    acceptance_rates: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the synthetic rejection kernel with one program per request.

    Returns:
        A ``(sampled, num_sampled)`` pair. ``sampled`` has shape
        ``[num_reqs, num_speculative_steps + 1]`` with the dtype and device
        of ``target_sampled``; it is allocated uninitialized, so only the
        entries written by the kernel are meaningful. ``num_sampled`` is an
        int32 tensor of shape ``[num_reqs]`` holding the per-request count
        of tokens written.
    """
    # cu_num_logits carries one more entry than there are requests.
    batch_size = cu_num_logits.size(0) - 1
    row_width = num_speculative_steps + 1

    out_tokens = target_sampled.new_empty((batch_size, row_width))
    out_counts = target_sampled.new_empty((batch_size,), dtype=torch.int32)

    launch_grid = (batch_size,)
    _synthetic_rejection_sample_kernel[launch_grid](
        out_tokens,
        out_tokens.stride(0),
        out_counts,
        target_sampled,
        draft_sampled,
        cu_num_logits,
        pos,
        idx_mapping,
        seed,
        acceptance_rates,
        num_warps=1,
    )
    return out_tokens, out_counts