mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Model Runner V2] Apply synthetic mode to probabilistic rejection sampler (#41035)
This commit is contained in:
@@ -106,13 +106,12 @@ steps:
|
||||
- vllm/v1/worker/gpu/
|
||||
- vllm/v1/worker/gpu_worker.py
|
||||
- tests/v1/spec_decode/test_max_len.py
|
||||
- tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
|
||||
- tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
|
||||
- tests/v1/spec_decode/test_rejection_sampler_utils.py
|
||||
- tests/v1/e2e/spec_decode/test_spec_decode.py
|
||||
commands:
|
||||
- set -x
|
||||
- export VLLM_USE_V2_MODEL_RUNNER=1
|
||||
- pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
|
||||
- pytest -v -s v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
|
||||
- pytest -v -s v1/spec_decode/test_rejection_sampler_utils.py
|
||||
- pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
|
||||
- pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"
|
||||
|
||||
+71
-4
@@ -6,8 +6,8 @@ import math
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
|
||||
probabilistic_rejection_sample,
|
||||
from vllm.v1.worker.gpu.spec_decode.rejection_sampler_utils import (
|
||||
rejection_sample,
|
||||
)
|
||||
|
||||
VOCAB_SIZE = 4096
|
||||
@@ -167,7 +167,7 @@ def test_stochastic_rejection_sample(num_speculative_steps: int, temperature: fl
|
||||
num_trials=num_trials,
|
||||
)
|
||||
|
||||
sampled, num_sampled = probabilistic_rejection_sample(
|
||||
sampled, num_sampled = rejection_sample(
|
||||
**inputs, num_speculative_steps=num_speculative_steps
|
||||
)
|
||||
|
||||
@@ -201,7 +201,7 @@ def test_greedy_rejection_sample(num_speculative_steps: int):
|
||||
num_trials=num_trials,
|
||||
)
|
||||
|
||||
sampled, num_sampled = probabilistic_rejection_sample(
|
||||
sampled, num_sampled = rejection_sample(
|
||||
**inputs, num_speculative_steps=num_speculative_steps
|
||||
)
|
||||
|
||||
@@ -213,3 +213,70 @@ def test_greedy_rejection_sample(num_speculative_steps: int):
|
||||
assert (sampled[accepted_mask] == target_argmax).all(), (
|
||||
"Greedy sampling produced tokens that are not the target argmax"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_speculative_steps,temperature,unconditional_rates",
    [
        (3, 1.0, [0.9, 0.5, 0.2]),
        (3, 0.0, [0.9, 0.5, 0.2]),
        (3, 1.0, [1.0, 1.0, 1.0]),
        (3, 0.0, [1.0, 1.0, 1.0]),
        (3, 1.0, [0.0, 0.0, 0.0]),
        (3, 0.0, [0.0, 0.0, 0.0]),
        (1, 1.0, [0.7]),
        (1, 0.0, [0.7]),
    ],
)
def test_synthetic_rejection_sample(
    num_speculative_steps: int,
    temperature: float,
    unconditional_rates: list[float],
):
    """
    Check the empirical acceptance rates of synthetic rejection sampling.

    ``unconditional_rates[i]`` is the probability that every draft step
    0..i is accepted, i.e. the product of the first ``i + 1`` conditional
    rates. Over many trials it should match the fraction of requests whose
    number of accepted draft tokens is at least ``i + 1``.
    """
    from vllm.v1.spec_decode.utils import unconditional_to_conditional_rates

    torch.manual_seed(42)
    device = "cuda"
    num_trials = 10 * VOCAB_SIZE
    deviation_tol = 1e-2

    # Draw target logits first, then draft logits, so the RNG stream is
    # consumed in a fixed order.
    target_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
    draft_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)

    # A temperature of 0 leaves the logits unscaled (and avoids dividing by 0).
    if temperature > 0:
        target_logits_1d /= temperature
        draft_logits_1d /= temperature

    inputs = _build_rejection_sample_inputs(
        target_logits_1d,
        draft_logits_1d,
        num_speculative_steps,
        temperature=temperature,
        num_trials=num_trials,
    )

    # The sampler consumes per-step conditional rates as a device tensor.
    synthetic_conditional_rates = torch.tensor(
        unconditional_to_conditional_rates(unconditional_rates),
        dtype=torch.float32,
        device=device,
    )

    _, num_sampled = rejection_sample(
        **inputs,
        num_speculative_steps=num_speculative_steps,
        synthetic_conditional_rates=synthetic_conditional_rates,
    )

    # num_sampled includes the resampled/bonus token.
    num_accepted = num_sampled - 1
    for step, expected in enumerate(unconditional_rates):
        reached_step = num_accepted >= step + 1
        observed = reached_step.float().mean().item()
        deviation = abs(observed - expected)
        assert deviation < deviation_tol, (
            f"Step {step}: observed rate {observed:.4f} deviates from "
            f"expected rate {expected:.4f} by more than {deviation_tol}."
        )
|
||||
@@ -12,11 +12,8 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
|
||||
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
|
||||
probabilistic_rejection_sample,
|
||||
)
|
||||
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
|
||||
synthetic_rejection_sample,
|
||||
from vllm.v1.worker.gpu.spec_decode.rejection_sampler_utils import (
|
||||
rejection_sample,
|
||||
)
|
||||
|
||||
|
||||
@@ -101,59 +98,42 @@ class RejectionSampler:
|
||||
input_batch: InputBatch,
|
||||
draft_logits: torch.Tensor | None = None,
|
||||
) -> SamplerOutput:
|
||||
draft_sampled = input_batch.input_ids[input_batch.logits_indices]
|
||||
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
|
||||
# that num_nans is computed before applying penalties and temperature.
|
||||
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
|
||||
|
||||
if self.rejection_sample_method == "standard":
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
processed_logits = self.sampler.apply_sampling_params(
|
||||
logits,
|
||||
input_batch.expanded_idx_mapping,
|
||||
input_batch.idx_mapping_np,
|
||||
pos,
|
||||
draft_sampled,
|
||||
input_batch.expanded_local_pos,
|
||||
)
|
||||
sampled, num_sampled = probabilistic_rejection_sample(
|
||||
processed_logits,
|
||||
draft_logits,
|
||||
draft_sampled,
|
||||
input_batch.cu_num_logits,
|
||||
pos,
|
||||
input_batch.idx_mapping,
|
||||
input_batch.expanded_idx_mapping,
|
||||
input_batch.expanded_local_pos,
|
||||
self.sampler.sampling_states.temperature.gpu,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
self.num_speculative_steps,
|
||||
)
|
||||
logprobs_tensors = self._get_logprobs_tensors(
|
||||
input_batch,
|
||||
sampled,
|
||||
num_sampled,
|
||||
processed_logits
|
||||
if self.sampler.logprobs_mode == "processed_logprobs"
|
||||
else logits,
|
||||
)
|
||||
elif self.rejection_sample_method == "synthetic":
|
||||
sampler_output = self.sampler(logits, input_batch)
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
sampled, num_sampled = synthetic_rejection_sample(
|
||||
sampler_output.sampled_token_ids.view(-1),
|
||||
draft_sampled,
|
||||
input_batch.cu_num_logits,
|
||||
input_batch.positions[input_batch.logits_indices],
|
||||
input_batch.idx_mapping,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
self.synthetic_conditional_rates,
|
||||
self.num_speculative_steps,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown rejection sample method: {self.rejection_sample_method}"
|
||||
)
|
||||
draft_sampled = input_batch.input_ids[input_batch.logits_indices]
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
processed_logits = self.sampler.apply_sampling_params(
|
||||
logits,
|
||||
input_batch.expanded_idx_mapping,
|
||||
input_batch.idx_mapping_np,
|
||||
pos,
|
||||
draft_sampled,
|
||||
input_batch.expanded_local_pos,
|
||||
)
|
||||
sampled, num_sampled = rejection_sample(
|
||||
processed_logits,
|
||||
draft_logits,
|
||||
draft_sampled,
|
||||
input_batch.cu_num_logits,
|
||||
pos,
|
||||
input_batch.idx_mapping,
|
||||
input_batch.expanded_idx_mapping,
|
||||
input_batch.expanded_local_pos,
|
||||
self.sampler.sampling_states.temperature.gpu,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
self.num_speculative_steps,
|
||||
self.synthetic_conditional_rates,
|
||||
)
|
||||
logprobs_tensors = self._get_logprobs_tensors(
|
||||
input_batch,
|
||||
sampled,
|
||||
num_sampled,
|
||||
processed_logits
|
||||
if self.sampler.logprobs_mode == "processed_logprobs"
|
||||
else logits,
|
||||
)
|
||||
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=sampled,
|
||||
|
||||
+31
-9
@@ -154,7 +154,7 @@ def _compute_block_stats_kernel(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _probabilistic_rejection_kernel(
|
||||
def _rejection_kernel(
|
||||
# [num_reqs, num_speculative_steps + 1]
|
||||
sampled_ptr,
|
||||
sampled_stride,
|
||||
@@ -198,9 +198,12 @@ def _probabilistic_rejection_kernel(
|
||||
seed_ptr,
|
||||
# [num_logits]
|
||||
pos_ptr,
|
||||
# [num_speculative_steps]
|
||||
synthetic_conditional_rates_ptr,
|
||||
vocab_num_blocks,
|
||||
PADDED_VOCAB_NUM_BLOCKS: tl.constexpr,
|
||||
HAS_DRAFT_LOGITS: tl.constexpr,
|
||||
SYNTHETIC_MODE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
|
||||
@@ -217,7 +220,7 @@ def _probabilistic_rejection_kernel(
|
||||
for i in range(num_tokens - 1):
|
||||
if accepted:
|
||||
logit_idx = start_idx + i
|
||||
draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
|
||||
draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1).to(tl.int64)
|
||||
if temp == 0.0:
|
||||
# Greedy sampling. Accept IFF draft matches target argmax.
|
||||
# NOTE: Target argmax is stored directly so that resampling
|
||||
@@ -236,9 +239,19 @@ def _probabilistic_rejection_kernel(
|
||||
target_local_argmax_ptr
|
||||
+ logit_idx * target_local_argmax_stride
|
||||
+ max_target_block_idx
|
||||
).to(tl.int64)
|
||||
|
||||
if SYNTHETIC_MODE:
|
||||
pos = tl.load(pos_ptr + logit_idx)
|
||||
u = tl_rand64(seed, pos, includes_zero=False)
|
||||
rate = tl.load(synthetic_conditional_rates_ptr + i)
|
||||
accepted &= u < rate
|
||||
else:
|
||||
accepted &= target_argmax == draft_sampled
|
||||
tl.store(
|
||||
sampled_ptr + req_idx * sampled_stride + i,
|
||||
draft_sampled if accepted else target_argmax,
|
||||
)
|
||||
accepted &= target_argmax == draft_sampled
|
||||
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_argmax)
|
||||
else:
|
||||
target_logit = tl.load(
|
||||
target_logits_ptr + logit_idx * target_logits_stride + draft_sampled
|
||||
@@ -275,9 +288,14 @@ def _probabilistic_rejection_kernel(
|
||||
else:
|
||||
# One-hot draft: q(draft_token) = 1, log_q = 0.
|
||||
draft_log_prob = 0
|
||||
# Probability ratio test: p(x) > u * q(x)
|
||||
# Equivalent log form: log_p(x) > log(u) + log_q(x)
|
||||
accepted &= target_log_prob > tl.log(u) + draft_log_prob
|
||||
|
||||
if SYNTHETIC_MODE:
|
||||
rate = tl.load(synthetic_conditional_rates_ptr + i)
|
||||
accepted &= u < rate
|
||||
else:
|
||||
# Probability ratio test: p(x) > u * q(x)
|
||||
# Equivalent log form: log_p(x) > log(u) + log_q(x)
|
||||
accepted &= target_log_prob > tl.log(u) + draft_log_prob
|
||||
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
|
||||
rejected_step += accepted
|
||||
tl.store(rejected_steps_ptr + req_idx, rejected_step)
|
||||
@@ -470,7 +488,7 @@ def _insert_resampled_kernel(
|
||||
)
|
||||
|
||||
|
||||
def probabilistic_rejection_sample(
|
||||
def rejection_sample(
|
||||
# [num_logits, V]
|
||||
target_logits: torch.Tensor,
|
||||
# [max_num_reqs, num_speculative_steps, V]
|
||||
@@ -492,6 +510,8 @@ def probabilistic_rejection_sample(
|
||||
# [max_num_reqs]
|
||||
seed: torch.Tensor,
|
||||
num_speculative_steps: int,
|
||||
# [num_speculative_steps]
|
||||
synthetic_conditional_rates: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
num_logits, vocab_size = target_logits.shape
|
||||
@@ -557,7 +577,7 @@ def probabilistic_rejection_sample(
|
||||
num_sampled = sampled.new_empty(num_reqs, dtype=torch.int32)
|
||||
target_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
|
||||
draft_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
|
||||
_probabilistic_rejection_kernel[(num_reqs,)](
|
||||
_rejection_kernel[(num_reqs,)](
|
||||
sampled,
|
||||
sampled.stride(0),
|
||||
num_sampled,
|
||||
@@ -584,9 +604,11 @@ def probabilistic_rejection_sample(
|
||||
temperature,
|
||||
seed,
|
||||
pos,
|
||||
synthetic_conditional_rates,
|
||||
vocab_num_blocks,
|
||||
PADDED_VOCAB_NUM_BLOCKS=padded_vocab_num_blocks,
|
||||
HAS_DRAFT_LOGITS=has_draft_logits,
|
||||
SYNTHETIC_MODE=synthetic_conditional_rates is not None,
|
||||
num_warps=1,
|
||||
)
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.sample.gumbel import tl_rand64
|
||||
|
||||
|
||||
@triton.jit
def _synthetic_rejection_sample_kernel(
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]
    num_sampled_ptr,
    # [num_draft_tokens + num_reqs]
    target_sampled_ptr,
    # [num_draft_tokens + num_reqs]
    input_ids_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    pos_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    seeds_ptr,
    # [num_speculative_steps]
    acceptance_rates_ptr,
):
    """Accept or reject draft tokens using fixed per-step acceptance rates.

    Each program handles the request given by ``tl.program_id(0)``. It walks
    the request's draft positions in order: at step ``i`` it draws ``u`` with
    ``tl_rand64(seed, pos)`` and keeps the draft token (``input_ids`` at the
    next logit index) when ``u < acceptance_rates[i]``. Otherwise it emits the
    target-sampled token for that position and stops. If no step is rejected,
    the target-sampled token at the request's last logit is appended. The
    number of emitted tokens is written to ``num_sampled_ptr``.
    """
    req_idx = tl.program_id(0)

    # The per-request seed is looked up through the request-state mapping.
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    seed = tl.load(seeds_ptr + req_state_idx)

    # [start_idx, end_idx) is this request's slice of the flattened logits.
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx
    last = num_tokens - 1

    out_row_ptr = sampled_ptr + req_idx * sampled_stride

    emitted = 0
    hit_rejection = False
    for i in range(last):
        if not hit_rejection:
            logit_idx = start_idx + i
            pos = tl.load(pos_ptr + logit_idx)
            u = tl_rand64(seed, pos, includes_zero=False)
            rate = tl.load(acceptance_rates_ptr + i)
            if u < rate:
                # Accepted: keep the draft token proposed for this step.
                token = tl.load(input_ids_ptr + logit_idx + 1).to(tl.int64)
            else:
                # Rejected: fall back to the target's token and stop.
                token = tl.load(target_sampled_ptr + logit_idx)
                hit_rejection = True
            tl.store(out_row_ptr + i, token)
            emitted += 1

    if not hit_rejection:
        # Every draft step was accepted; append the token sampled at the
        # request's final logit.
        bonus = tl.load(target_sampled_ptr + start_idx + last)
        tl.store(out_row_ptr + last, bonus)
        emitted += 1
    tl.store(num_sampled_ptr + req_idx, emitted)
|
||||
|
||||
|
||||
def synthetic_rejection_sample(
    # [num_draft_tokens + num_reqs]
    target_sampled: torch.Tensor,
    # [num_draft_tokens + num_reqs]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    seed: torch.Tensor,
    # [num_speculative_steps]
    acceptance_rates: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the synthetic rejection kernel, one program per request.

    Returns:
        A ``(sampled, num_sampled)`` pair. ``sampled`` has shape
        ``[num_reqs, num_speculative_steps + 1]`` with the dtype and device
        of ``target_sampled``. It is allocated uninitialized, so the kernel
        only fills the entries it emits. ``num_sampled`` is an int32 tensor
        of shape ``[num_reqs]`` holding the emitted count per request.
    """
    # cu_num_logits carries one more entry than there are requests.
    num_reqs = cu_num_logits.shape[0] - 1

    out_shape = (num_reqs, num_speculative_steps + 1)
    sampled = target_sampled.new_empty(out_shape)
    num_sampled = target_sampled.new_empty(num_reqs, dtype=torch.int32)

    grid = (num_reqs,)
    _synthetic_rejection_sample_kernel[grid](
        sampled,
        sampled.stride(0),
        num_sampled,
        target_sampled,
        draft_sampled,
        cu_num_logits,
        pos,
        idx_mapping,
        seed,
        acceptance_rates,
        num_warps=1,
    )
    return sampled, num_sampled
|
||||
Reference in New Issue
Block a user