[None][feat] Add test for speculative rejection sampler (2-model) (#6542)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Izzy Putterman 2025-08-13 19:09:35 -07:00 committed by GitHub
parent eb4ed18a63
commit ef53de8eef
2 changed files with 209 additions and 79 deletions

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Literal
from typing import Literal, Optional
import torch
@@ -97,7 +97,9 @@ class EarlyStopSampler(Sampler):
request.py_result.append_context_logits(logits)
def top_k_sampling_batch(logits, top_k=50, generator: torch.Generator = None):
def top_k_sampling_batch(logits,
top_k=50,
generator: Optional[torch.Generator] = None):
logits_dim = logits.dim()
if logits_dim == 1:
logits = logits.unsqueeze(0)
@@ -124,7 +126,7 @@ def top_k_sampling_batch(logits, top_k=50, generator: torch.Generator = None):
def top_p_sampling_batch(logits: torch.Tensor,
top_p: float = 0.9,
temperature: float = 1.0,
generator: torch.Generator = None):
generator: Optional[torch.Generator] = None):
logits_dim = logits.dim()
if logits_dim == 1:
logits = logits.unsqueeze(0)
@@ -136,6 +138,50 @@ def top_p_sampling_batch(logits: torch.Tensor,
# sort the logits of each sample in descending order
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
# compute cumulative probability distribution of each sample
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
dim=-1)
# find the positions where cumulative probability exceeds top_p
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
# set logits outside the top_p nucleus to -inf
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float('-inf'))
# compute probability distribution
softmax = torch.softmax(logits, dim=-1)
# sample from the distribution; multinomial returns [batch_size, 1], squeezed to [batch_size]
next_tokens = torch.multinomial(softmax, num_samples=1,
generator=generator).squeeze(-1)
return next_tokens, softmax
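Editorial sketch (toy numbers, not part of the diff): the shift-by-one above keeps the token whose cumulative probability first crosses top_p, so at least one token always survives the nucleus filter.

import torch

probs = torch.tensor([[0.50, 0.25, 0.15, 0.10]])  # already sorted descending, top_p = 0.6
cum = torch.cumsum(probs, dim=-1)                 # [[0.50, 0.75, 0.90, 1.00]]
remove = cum > 0.6                                # [[False, True, True, True]]
remove[:, 1:] = remove[:, :-1].clone()            # shift right by one position
remove[:, 0] = False                              # [[False, False, True, True]]
# tokens 0 and 1 survive, so the retained mass (0.75) reaches at least top_p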
def top_k_top_p_sampling_batch(logits: torch.Tensor,
top_k: int,
top_p: float,
temperature: float = 1.0,
generator: Optional[torch.Generator] = None):
logits_dim = logits.dim()
if logits_dim == 1:
logits = logits.unsqueeze(0)
    assert logits.dim() == 2, "logits should be 2D: [batch_size, vocab_size]"
if temperature != 0:
logits = logits / max(temperature, 1e-5)
batch_size, vocab_size = logits.size()
# get the top_k largest logits of each sample and their indices
values, indices = torch.topk(logits, top_k, dim=-1)
min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
# set logits smaller than the k-th largest logit to -inf
logits = torch.where(logits < min_values,
torch.full_like(logits, float('-inf')), logits)
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
# compute cumulative probability distribution of each sample
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
dim=-1)
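Editorial sketch (illustrative values, not part of the diff): before the nucleus step continues in the elided lines of this hunk, the top_k threshold masks every logit below the k-th largest one.

import torch

logits = torch.tensor([[2.0, 0.5, 1.5, -1.0, 3.0]])  # 5-token toy vocabulary, top_k = 3
values, indices = torch.topk(logits, k=3, dim=-1)    # values == [[3.0, 2.0, 1.5]]
min_values = values[:, -1:].expand_as(logits)        # k-th largest logit, broadcast
masked = torch.where(logits < min_values,
                     torch.full_like(logits, float('-inf')), logits)
# masked == [[2.0, -inf, 1.5, -inf, 3.0]]; top_p filtering then runs on this tensor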
@@ -165,14 +211,50 @@ def greedy_search_sampling_batch(logits):
return next_tokens, softmax
def get_rejected_indices(draft_probs: torch.Tensor, target_probs: torch.Tensor,
generator: torch.Generator, draft_tokens: list[int]):
p = draft_probs[torch.arange(len(draft_tokens)), draft_tokens]
q = target_probs[:-1]
q = q[torch.arange(len(draft_tokens)), draft_tokens]
accept_probs = torch.minimum(torch.ones(()), q / p)
# Use deterministic random generation for multi-GPU consistency
rejected_indices = (torch.rand(accept_probs.shape,
generator=generator,
device=accept_probs.device)
> accept_probs).nonzero()
return rejected_indices
def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor,
generator: torch.Generator, num_accepted: int):
last_draft = draft_probs[num_accepted]
last_target = target_probs[num_accepted]
new = last_target - last_draft
new = torch.where(new > 0, new, 0.0)
new_token = torch.multinomial(new, num_samples=1,
generator=generator).squeeze(-1)
return new_token
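Editorial note: get_rejected_indices and sample_rejected implement standard speculative (rejection) sampling: a drafted token t is accepted with probability min(1, q(t)/p(t)), and on rejection a replacement is drawn from the normalized residual max(q - p, 0). A minimal analytic sketch (toy 3-token distributions, not from the diff) shows that this recovers the target distribution exactly, which is the property the new unit test checks empirically:

import torch

p = torch.tensor([0.5, 0.3, 0.2])   # draft distribution
q = torch.tensor([0.2, 0.5, 0.3])   # target distribution

accept = torch.minimum(torch.ones(()), q / p)   # per-token acceptance probability
residual = torch.clamp(q - p, min=0.0)
residual = residual / residual.sum()            # distribution used after a rejection
p_reject = (p * (1.0 - accept)).sum()           # overall probability of rejecting

marginal = p * accept + p_reject * residual     # accept-or-resample marginal
assert torch.allclose(marginal, q)              # equals the target distribution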
TopK = tuple[Literal["top_k"], int]
TopP = tuple[Literal["top_p"], float, float]
TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
Greedy = tuple[Literal["greedy"], None]
GREEDY: Greedy = ("greedy", None)
Strategy = TopK | TopP | TopKTopP | Greedy
def request_strategy(request: LlmRequest) -> Strategy:
if request.sampling_config.top_k is not None and len(
request.sampling_config.top_k
) > 0 and request.sampling_config.top_p is not None and len(
request.sampling_config.top_p) > 0:
return ("top_k_top_p", request.sampling_config.top_k[0],
request.sampling_config.top_p[0],
request.sampling_config.temperature[0])
if request.sampling_config.top_p is not None and len(
request.sampling_config.top_p) > 0:
return ("top_p", request.sampling_config.top_p[0],
@@ -190,12 +272,15 @@ def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]:
def sample(strategy: Strategy,
logits: torch.Tensor,
generator: torch.Generator = None):
generator: Optional[torch.Generator] = None):
match strategy:
case ("top_k", top_k):
return top_k_sampling_batch(logits, top_k, generator)
case ("top_p", top_p, temperature):
return top_p_sampling_batch(logits, top_p, temperature, generator)
case ("top_k_top_p", top_k, top_p, temperature):
return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature,
generator)
case ("greedy", None):
return greedy_search_sampling_batch(logits)
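Editorial sketch (parameter values and shapes are illustrative; sample and GREEDY are the definitions above): the strategy tuples drive the dispatch, and each branch returns the sampled tokens together with the softmax that the rejection path later consumes.

import torch

gen = torch.Generator().manual_seed(0)
logits = torch.randn(2, 32000)   # [batch_size, vocab_size]

tokens, probs = sample(("top_k_top_p", 50, 0.9, 1.0), logits, generator=gen)
tokens, probs = sample(("top_p", 0.9, 1.0), logits, generator=gen)
tokens, probs = sample(("top_k", 50), logits, generator=gen)
tokens, probs = sample(GREEDY, logits)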
@@ -331,84 +416,77 @@ class TorchSampler(Sampler):
assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
request.py_result.append_log_probs([token_log_probs])
def _process_draft_tokens_greedy(self, request: LlmRequest,
new_tokens: torch.Tensor) -> int:
new_token = add_token(request, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
if stop or get_draft_token_length(request) == 0:
return 0
num_accepted = 0
for draft_token in request.py_draft_tokens:
if draft_token != new_token:
# Reject.
break
num_accepted += 1
new_token = add_token(request,
new_tokens,
beam=self.BEAM,
step=num_accepted)
if self._handle_stop_criteria(request, new_token):
break
return num_accepted
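Editorial sketch (function name is hypothetical): stripped of the LlmRequest/add_token bookkeeping, the greedy path above is a prefix match between the drafted tokens and the tokens the target model actually produces.

def count_accepted_greedy(draft_tokens, target_tokens):
    """Count how many leading draft tokens match the target model's tokens."""
    accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break   # first mismatch rejects this and every later draft token
        accepted += 1
    return accepted

assert count_accepted_greedy([5, 7, 9], [5, 7, 8, 2]) == 2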
def _process_draft_tokens_rejection_sampling(
self, request: LlmRequest, new_tokens: torch.Tensor) -> int:
sampling_strategy = request_strategy(request)
generator = self.get_generator(request.py_draft_logits.device)
_, draft_probs = sample(sampling_strategy,
request.py_draft_logits[0],
generator=generator)
target_probs = request.py_target_probs
rejected_indices = get_rejected_indices(draft_probs, target_probs,
generator,
request.py_draft_tokens)
sample_last = True
stop = False
if rejected_indices.numel() == 0:
num_initially_accepted = get_draft_token_length(request)
sample_last = False
else:
num_initially_accepted = rejected_indices[0].item()
num_accepted = num_initially_accepted
for i in range(num_accepted):
new_token = request.py_draft_tokens[i]
new_tokens[i, request.seq_slot, self.BEAM] = new_token
request.add_new_token(new_token, self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
if stop:
num_accepted = i + 1
return num_accepted
if sample_last:
new_token = sample_rejected(draft_probs, target_probs, generator,
num_accepted)
new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
request.add_new_token(new_token, self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
else:
new_token = add_token(request,
new_tokens,
beam=self.BEAM,
step=num_accepted)
stop = self._handle_stop_criteria(request, new_token)
return num_accepted
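Editorial sketch (function name is hypothetical; it reuses the two helpers added earlier): without the request bookkeeping and stop-criteria handling, the control flow above reduces to the following.

def accept_or_resample(draft_probs, target_probs, draft_tokens, generator):
    rejected = get_rejected_indices(draft_probs, target_probs, generator,
                                    draft_tokens)
    if rejected.numel() == 0:
        # every draft token accepted; the bonus token comes from the target's
        # final logits row (the add_token path in the sampler above)
        return len(draft_tokens), None
    num_accepted = rejected[0].item()
    corrected = sample_rejected(draft_probs, target_probs, generator,
                                num_accepted)
    return num_accepted, corrected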
def process_draft_tokens(self, request: LlmRequest,
new_tokens: torch.Tensor) -> int:
if request.py_draft_logits is None:
new_token = add_token(request, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
if stop or get_draft_token_length(request) == 0:
return 0
num_accepted = 0
for draft_token in request.py_draft_tokens:
if draft_token != new_token:
# Reject.
break
num_accepted += 1
new_token = add_token(request,
new_tokens,
beam=self.BEAM,
step=num_accepted)
if self._handle_stop_criteria(request, new_token):
break
return num_accepted
return self._process_draft_tokens_greedy(request, new_tokens)
else:
sampling_strategy = request_strategy(request)
generator = self.get_generator(request.py_draft_logits.device)
_, draft_probs = sample(sampling_strategy,
request.py_draft_logits[0],
generator=generator)
target_probs = request.py_target_probs
p = draft_probs[torch.arange(get_draft_token_length(request)),
request.py_draft_tokens]
q = target_probs[:-1]
q = q[torch.arange(get_draft_token_length(request)),
request.py_draft_tokens]
accept_probs = torch.minimum(torch.ones(()), q / p)
# Use deterministic random generation for multi-GPU consistency
rejected_indices = (torch.rand(accept_probs.shape,
generator=generator,
device=accept_probs.device)
> accept_probs).nonzero()
sample_last = True
stop = False
if rejected_indices.numel() == 0:
num_initially_accepted = get_draft_token_length(request)
sample_last = False
else:
num_initially_accepted = rejected_indices[0].item()
num_accepted = num_initially_accepted
for i in range(num_accepted):
new_token = request.py_draft_tokens[i]
new_tokens[i, request.seq_slot, self.BEAM] = new_token
request.add_new_token(new_token, self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
if stop:
num_accepted = i + 1
break
if not stop and sample_last:
last_draft = draft_probs[num_accepted]
last_target = target_probs[num_accepted]
new = last_target - last_draft
new = torch.where(new > 0, new, 0.0)
new_token = torch.multinomial(new,
num_samples=1,
generator=generator).squeeze(-1)
new_tokens[num_accepted, request.seq_slot,
self.BEAM] = new_token
request.add_new_token(new_token, self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
elif not stop and not sample_last:
new_token = add_token(request,
new_tokens,
beam=self.BEAM,
step=num_accepted)
stop = self._handle_stop_criteria(request, new_token)
return num_accepted
return self._process_draft_tokens_rejection_sampling(
request, new_tokens)
def update_requests(self, state: SampleState) -> None:
assert isinstance(state, SampleState)
@@ -601,7 +679,6 @@ class TorchSampler(Sampler):
current_slice = slice(0, steps), slot, beam
new_tokens[current_slice] = next_tokens
if request.py_draft_logits is not None:
# Could be cleaner
request.py_target_probs = softmax.clone()
if gen_logits_host is not None:
gen_logits_host[current_slice].copy_(logits, non_blocking=True)
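Editorial note (shapes below are an assumption inferred from the code, not stated in the diff): py_target_probs keeps one more row than there are draft tokens, which is why get_rejected_indices drops the last row via target_probs[:-1].

# draft_probs          : [num_draft_tokens, vocab_size]      (softmax over py_draft_logits[0])
# py_target_probs      : [num_draft_tokens + 1, vocab_size]  (one extra "bonus" step)
# py_target_probs[:-1] : row i is compared against draft token i
# py_target_probs[-1]  : only used, via add_token, when every draft token is accepted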

View File

@@ -0,0 +1,53 @@
import unittest
import numpy as np
import torch
from scipy.stats import entropy
from tensorrt_llm._torch.pyexecutor.sampler import (get_rejected_indices,
sample_rejected)
def test_get_rejected_indices():
vocab_size = 500
num_iter = 50000
draft_probs = torch.rand(1, vocab_size)
drop_idx = torch.topk(draft_probs[0], k=400, largest=False)[1]
draft_probs[0, drop_idx] = 0.0
draft_probs = draft_probs / draft_probs.sum(dim=-1, keepdim=True)
target_probs = torch.rand(2, vocab_size)
drop_idx = torch.topk(target_probs[0], k=400, largest=False)[1]
target_probs[0, drop_idx] = 0.0
target_probs = target_probs / target_probs.sum(dim=-1, keepdim=True)
generator = torch.Generator()
sampled_tokens = []
sampled_regular = []
for _ in range(num_iter):
draft_tokens = [
torch.multinomial(draft_probs, num_samples=1,
generator=generator).item()
]
rejected_indices = get_rejected_indices(draft_probs, target_probs,
generator, draft_tokens)
if rejected_indices.shape[0] == 0:
sampled_tokens.append(draft_tokens[0])
else:
sampled_tokens.append(
sample_rejected(draft_probs, target_probs, generator, 0).item())
sampled_regular.append(
torch.multinomial(target_probs[0],
num_samples=1,
generator=generator).item())
bins = np.arange(vocab_size + 1) - 0.5 # Bins for histogram
sampled_tokens, _ = np.histogram(sampled_tokens, bins=bins, density=True)
sampled_regular, _ = np.histogram(sampled_regular, bins=bins, density=True)
expected_prob = target_probs[0].squeeze().numpy()
# KL Divergence check
kl_divergence = entropy(expected_prob, sampled_tokens)
kl_divergence_regular = entropy(expected_prob, sampled_regular)
assert abs(kl_divergence - kl_divergence_regular) < 0.01
if __name__ == "__main__":
unittest.main()