Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] Add test for speculative rejection sampler (2-model) (#6542)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
parent eb4ed18a63
commit ef53de8eef
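Aside (illustrative, not part of the commit): the two-model rejection sampler exercised by this diff follows the usual accept/resample rule. A draft token x with draft-model probability p(x) and target-model probability q(x) is accepted with probability min(1, q(x)/p(x)); at the first rejection, a replacement is drawn from the normalized positive residual max(q - p, 0). A minimal self-contained sketch of that rule on toy tensors (all names and values below are made up for illustration):

import torch

# Toy single-step distributions over a 4-token vocabulary (illustrative values).
draft_probs = torch.tensor([0.40, 0.30, 0.20, 0.10])   # p, from the draft model
target_probs = torch.tensor([0.10, 0.30, 0.40, 0.20])  # q, from the target model

draft_token = torch.multinomial(draft_probs, num_samples=1).item()

# Accept with probability min(1, q/p), as get_rejected_indices does below.
accept_prob = torch.minimum(
    torch.ones(()), target_probs[draft_token] / draft_probs[draft_token])
if torch.rand(()) < accept_prob:
    token = draft_token
else:
    # Otherwise resample from the positive residual, as sample_rejected does below.
    residual = torch.clamp(target_probs - draft_probs, min=0.0)
    token = torch.multinomial(residual, num_samples=1).item()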
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Optional
 
 import torch
 
@@ -97,7 +97,9 @@ class EarlyStopSampler(Sampler):
                 request.py_result.append_context_logits(logits)
 
 
-def top_k_sampling_batch(logits, top_k=50, generator: torch.Generator = None):
+def top_k_sampling_batch(logits,
+                         top_k=50,
+                         generator: Optional[torch.Generator] = None):
     logits_dim = logits.dim()
     if logits_dim == 1:
         logits = logits.unsqueeze(0)
@@ -124,7 +126,7 @@ def top_k_sampling_batch(logits, top_k=50, generator: torch.Generator = None):
 def top_p_sampling_batch(logits: torch.Tensor,
                          top_p: float = 0.9,
                          temperature: float = 1.0,
-                         generator: torch.Generator = None):
+                         generator: Optional[torch.Generator] = None):
     logits_dim = logits.dim()
     if logits_dim == 1:
         logits = logits.unsqueeze(0)
@@ -136,6 +138,50 @@ def top_p_sampling_batch(logits: torch.Tensor,
     # sort the logits of each sample in descending order
     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
 
+    # compute cumulative probability distribution of each sample
+    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
+                                    dim=-1)
+    # get the location of top_p
+    sorted_indices_to_remove = cumulative_probs > top_p
+    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+    sorted_indices_to_remove[:, 0] = 0
+
+    # set the logits to -inf whose is outside top_p
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        1, sorted_indices, sorted_indices_to_remove)
+    logits = logits.masked_fill(indices_to_remove, float('-inf'))
+
+    # compute probability distribution
+    softmax = torch.softmax(logits, dim=-1)
+
+    # sample from the distribution and generate result of [batch_size, 1]
+    next_tokens = torch.multinomial(softmax, num_samples=1,
+                                    generator=generator).squeeze(-1)
+    return next_tokens, softmax
+
+
+def top_k_top_p_sampling_batch(logits: torch.Tensor,
+                               top_k: int,
+                               top_p: float,
+                               temperature: float = 1.0,
+                               generator: Optional[torch.Generator] = None):
+    logits_dim = logits.dim()
+    if logits_dim == 1:
+        logits = logits.unsqueeze(0)
+    assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
+    if temperature != 0:
+        logits = logits / max(temperature, 1e-5)
+    batch_size, vocab_size = logits.size()
+    # get first top_k logits of each sample and their indices
+    values, indices = torch.topk(logits, top_k, dim=-1)
+    min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
+
+    # set the logits who is less than first top_k logits to -inf
+    logits = torch.where(logits < min_values,
+                         torch.full_like(logits, float('-inf')), logits)
+
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+
     # compute cumulative probability distribution of each sample
     cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
                                     dim=-1)
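Aside (illustrative, not part of the commit): a possible usage sketch for the combined top-k/top-p helper added in this hunk. The import path is taken from the unit test added later in this diff; the batch shape and sampling parameters are arbitrary.

import torch

from tensorrt_llm._torch.pyexecutor.sampler import top_k_top_p_sampling_batch

logits = torch.randn(2, 32000)  # toy [batch_size, vocab_size] logits
generator = torch.Generator().manual_seed(0)
# Returns the sampled token per row plus the filtered softmax it was drawn from,
# matching the (tokens, probs) pair returned by the other *_sampling_batch helpers.
next_tokens, probs = top_k_top_p_sampling_batch(logits,
                                                top_k=50,
                                                top_p=0.9,
                                                temperature=0.8,
                                                generator=generator)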
@@ -165,14 +211,50 @@ def greedy_search_sampling_batch(logits):
     return next_tokens, softmax
 
 
+def get_rejected_indices(draft_probs: torch.Tensor, target_probs: torch.Tensor,
+                         generator: torch.Generator, draft_tokens: list[int]):
+
+    p = draft_probs[torch.arange(len(draft_tokens)), draft_tokens]
+    q = target_probs[:-1]
+    q = q[torch.arange(len(draft_tokens)), draft_tokens]
+    accept_probs = torch.minimum(torch.ones(()), q / p)
+    # Use deterministic random generation for multi-GPU consistency
+    rejected_indices = (torch.rand(accept_probs.shape,
+                                   generator=generator,
+                                   device=accept_probs.device)
+                        > accept_probs).nonzero()
+    return rejected_indices
+
+
+def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor,
+                    generator: torch.Generator, num_accepted: int):
+
+    last_draft = draft_probs[num_accepted]
+    last_target = target_probs[num_accepted]
+    new = last_target - last_draft
+    new = torch.where(new > 0, new, 0.0)
+
+    new_token = torch.multinomial(new, num_samples=1,
+                                  generator=generator).squeeze(-1)
+    return new_token
+
+
 TopK = tuple[Literal["top_k"], int]
 TopP = tuple[Literal["top_p"], float, float]
+TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
 Greedy = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
 Strategy = TopK | TopP | Greedy
 
 
 def request_strategy(request: LlmRequest) -> Strategy:
+    if request.sampling_config.top_k is not None and len(
+            request.sampling_config.top_k
+    ) > 0 and request.sampling_config.top_p is not None and len(
+            request.sampling_config.top_p) > 0:
+        return ("top_k_top_p", request.sampling_config.top_k[0],
+                request.sampling_config.top_p[0],
+                request.sampling_config.temperature[0])
     if request.sampling_config.top_p is not None and len(
             request.sampling_config.top_p) > 0:
         return ("top_p", request.sampling_config.top_p[0],
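Aside (illustrative, not part of the commit): a sketch of how the two new helpers above fit together for a multi-token draft, mirroring the way TorchSampler uses them later in this diff. Shapes and values are made up: draft_probs has one row per draft token, and target_probs has one extra row that scores the bonus token.

import torch

from tensorrt_llm._torch.pyexecutor.sampler import (get_rejected_indices,
                                                     sample_rejected)

vocab_size, num_draft = 8, 3
generator = torch.Generator().manual_seed(0)
draft_probs = torch.softmax(torch.randn(num_draft, vocab_size), dim=-1)
target_probs = torch.softmax(torch.randn(num_draft + 1, vocab_size), dim=-1)
draft_tokens = [
    torch.multinomial(draft_probs[i], num_samples=1, generator=generator).item()
    for i in range(num_draft)
]

rejected = get_rejected_indices(draft_probs, target_probs, generator,
                                draft_tokens)
if rejected.numel() == 0:
    num_accepted = num_draft  # every draft token was accepted
else:
    num_accepted = rejected[0].item()  # position of the first rejection
    replacement = sample_rejected(draft_probs, target_probs, generator,
                                  num_accepted)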
@@ -190,12 +272,15 @@ def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]:
 
 def sample(strategy: Strategy,
            logits: torch.Tensor,
-           generator: torch.Generator = None):
+           generator: Optional[torch.Generator] = None):
     match strategy:
         case ("top_k", top_k):
            return top_k_sampling_batch(logits, top_k, generator)
         case ("top_p", top_p, temperature):
             return top_p_sampling_batch(logits, top_p, temperature, generator)
+        case ("top_k_top_p", top_k, top_p, temperature):
+            return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature,
+                                              generator)
         case ("greedy", None):
             return greedy_search_sampling_batch(logits)
 
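Aside (illustrative, not part of the commit): the strategy tuples dispatched by the match statement above can also be built by hand; a small sketch with arbitrary parameter values, where each call returns the same (tokens, probs) pair as the underlying helper.

import torch

from tensorrt_llm._torch.pyexecutor.sampler import sample

logits = torch.randn(4, 32000)  # toy [batch_size, vocab_size] logits
generator = torch.Generator().manual_seed(0)

tokens, probs = sample(("top_k_top_p", 50, 0.9, 1.0), logits, generator)
tokens, probs = sample(("top_p", 0.9, 1.0), logits, generator)
tokens, probs = sample(("greedy", None), logits)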
@@ -331,84 +416,77 @@ class TorchSampler(Sampler):
             assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
             request.py_result.append_log_probs([token_log_probs])
 
+    def _process_draft_tokens_greedy(self, request: LlmRequest,
+                                     new_tokens: torch.Tensor) -> int:
+        new_token = add_token(request, new_tokens, beam=self.BEAM)
+        stop = self._handle_stop_criteria(request, new_token)
+        if stop or get_draft_token_length(request) == 0:
+            return 0
+        num_accepted = 0
+
+        for draft_token in request.py_draft_tokens:
+            if draft_token != new_token:
+                # Reject.
+                break
+
+            num_accepted += 1
+            new_token = add_token(request,
+                                  new_tokens,
+                                  beam=self.BEAM,
+                                  step=num_accepted)
+            if self._handle_stop_criteria(request, new_token):
+                break
+        return num_accepted
+
+    def _process_draft_tokens_rejection_sampling(
+            self, request: LlmRequest, new_tokens: torch.Tensor) -> int:
+        sampling_strategy = request_strategy(request)
+        generator = self.get_generator(request.py_draft_logits.device)
+        _, draft_probs = sample(sampling_strategy,
+                                request.py_draft_logits[0],
+                                generator=generator)
+        target_probs = request.py_target_probs
+        rejected_indices = get_rejected_indices(draft_probs, target_probs,
+                                                generator,
+                                                request.py_draft_tokens)
+        sample_last = True
+        stop = False
+        if rejected_indices.numel() == 0:
+            num_initially_accepted = get_draft_token_length(request)
+            sample_last = False
+        else:
+            num_initially_accepted = rejected_indices[0].item()
+        num_accepted = num_initially_accepted
+        for i in range(num_accepted):
+            new_token = request.py_draft_tokens[i]
+            new_tokens[i, request.seq_slot, self.BEAM] = new_token
+            request.add_new_token(new_token, self.BEAM)
+            stop = self._handle_stop_criteria(request, new_token)
+            if stop:
+                num_accepted = i + 1
+                return num_accepted
+        if sample_last:
+            new_token = sample_rejected(draft_probs, target_probs, generator,
+                                        num_accepted)
+            new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
+            request.add_new_token(new_token, self.BEAM)
+            stop = self._handle_stop_criteria(request, new_token)
+        else:
+            new_token = add_token(request,
+                                  new_tokens,
+                                  beam=self.BEAM,
+                                  step=num_accepted)
+            stop = self._handle_stop_criteria(request, new_token)
+
+        return num_accepted
+
     def process_draft_tokens(self, request: LlmRequest,
                              new_tokens: torch.Tensor) -> int:
         if request.py_draft_logits is None:
-            new_token = add_token(request, new_tokens, beam=self.BEAM)
-            stop = self._handle_stop_criteria(request, new_token)
-            if stop or get_draft_token_length(request) == 0:
-                return 0
-            num_accepted = 0
-
-            for draft_token in request.py_draft_tokens:
-                if draft_token != new_token:
-                    # Reject.
-                    break
-
-                num_accepted += 1
-                new_token = add_token(request,
-                                      new_tokens,
-                                      beam=self.BEAM,
-                                      step=num_accepted)
-                if self._handle_stop_criteria(request, new_token):
-                    break
-            return num_accepted
+            return self._process_draft_tokens_greedy(request, new_tokens)
         else:
-            sampling_strategy = request_strategy(request)
-            generator = self.get_generator(request.py_draft_logits.device)
-            _, draft_probs = sample(sampling_strategy,
-                                    request.py_draft_logits[0],
-                                    generator=generator)
-            target_probs = request.py_target_probs
-            p = draft_probs[torch.arange(get_draft_token_length(request)),
-                            request.py_draft_tokens]
-            q = target_probs[:-1]
-            q = q[torch.arange(get_draft_token_length(request)),
-                  request.py_draft_tokens]
-            accept_probs = torch.minimum(torch.ones(()), q / p)
-            # Use deterministic random generation for multi-GPU consistency
-            rejected_indices = (torch.rand(accept_probs.shape,
-                                           generator=generator,
-                                           device=accept_probs.device)
-                                > accept_probs).nonzero()
-            sample_last = True
-            stop = False
-            if rejected_indices.numel() == 0:
-                num_initially_accepted = get_draft_token_length(request)
-                sample_last = False
-            else:
-                num_initially_accepted = rejected_indices[0].item()
-            num_accepted = num_initially_accepted
-            for i in range(num_accepted):
-                new_token = request.py_draft_tokens[i]
-                new_tokens[i, request.seq_slot, self.BEAM] = new_token
-                request.add_new_token(new_token, self.BEAM)
-                stop = self._handle_stop_criteria(request, new_token)
-                if stop:
-                    num_accepted = i + 1
-                    break
-            if not stop and sample_last:
-                last_draft = draft_probs[num_accepted]
-                last_target = target_probs[num_accepted]
-                new = last_target - last_draft
-                new = torch.where(new > 0, new, 0.0)
-
-                new_token = torch.multinomial(new,
-                                              num_samples=1,
-                                              generator=generator).squeeze(-1)
-
-                new_tokens[num_accepted, request.seq_slot,
-                           self.BEAM] = new_token
-                request.add_new_token(new_token, self.BEAM)
-                stop = self._handle_stop_criteria(request, new_token)
-            elif not stop and not sample_last:
-                new_token = add_token(request,
-                                      new_tokens,
-                                      beam=self.BEAM,
-                                      step=num_accepted)
-                stop = self._handle_stop_criteria(request, new_token)
-
-            return num_accepted
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens)
 
     def update_requests(self, state: SampleState) -> None:
         assert isinstance(state, SampleState)
@@ -601,7 +679,6 @@
             current_slice = slice(0, steps), slot, beam
             new_tokens[current_slice] = next_tokens
             if request.py_draft_logits is not None:
-                # Could be cleaner
                 request.py_target_probs = softmax.clone()
             if gen_logits_host is not None:
                 gen_logits_host[current_slice].copy_(logits, non_blocking=True)
 
@@ -0,0 +1,53 @@
+import unittest
+
+import numpy as np
+import torch
+from scipy.stats import entropy
+
+from tensorrt_llm._torch.pyexecutor.sampler import (get_rejected_indices,
+                                                    sample_rejected)
+
+
+def test_get_rejected_indices():
+    vocab_size = 500
+    num_iter = 50000
+    draft_probs = torch.rand(1, vocab_size)
+    drop_idx = torch.topk(draft_probs[0], k=400, largest=False)[1]
+    draft_probs[0, drop_idx] = 0.0
+    draft_probs = draft_probs / draft_probs.sum(dim=-1, keepdim=True)
+    target_probs = torch.rand(2, vocab_size)
+    drop_idx = torch.topk(target_probs[0], k=400, largest=False)[1]
+    target_probs[0, drop_idx] = 0.0
+    target_probs = target_probs / target_probs.sum(dim=-1, keepdim=True)
+    generator = torch.Generator()
+    sampled_tokens = []
+    sampled_regular = []
+    for _ in range(num_iter):
+        draft_tokens = [
+            torch.multinomial(draft_probs, num_samples=1,
+                              generator=generator).item()
+        ]
+        rejected_indices = get_rejected_indices(draft_probs, target_probs,
+                                                generator, draft_tokens)
+        if rejected_indices.shape[0] == 0:
+            sampled_tokens.append(draft_tokens[0])
+        else:
+            sampled_tokens.append(
+                sample_rejected(draft_probs, target_probs, generator, 0).item())
+        sampled_regular.append(
+            torch.multinomial(target_probs[0],
+                              num_samples=1,
+                              generator=generator).item())
+    bins = np.arange(vocab_size + 1) - 0.5  # Bins for histogram
+    sampled_tokens, _ = np.histogram(sampled_tokens, bins=bins, density=True)
+    sampled_regular, _ = np.histogram(sampled_regular, bins=bins, density=True)
+    expected_prob = target_probs[0].squeeze().numpy()
+
+    # KL Divergence check
+    kl_divergence = entropy(expected_prob, sampled_tokens)
+    kl_divergence_regular = entropy(expected_prob, sampled_regular)
+    assert abs(kl_divergence - kl_divergence_regular) < 0.01
+
+
+if __name__ == "__main__":
+    unittest.main()
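Aside (illustrative, not part of the commit): the KL comparison in the test above is expected to pass because rejection sampling with acceptance probability min(1, q/p), followed by resampling from the positive residual, reproduces the target distribution q exactly; the speculative histogram and the plain-sampling histogram therefore estimate the same distribution. A small numeric check of that identity on toy tensors:

import torch

p = torch.softmax(torch.randn(16), dim=-1)  # toy draft distribution
q = torch.softmax(torch.randn(16), dim=-1)  # toy target distribution

accepted_mass = torch.minimum(p, q)       # P(draft token is x and it is accepted)
residual = torch.clamp(q - p, min=0.0)    # unnormalized resampling weights
reject_prob = 1.0 - accepted_mass.sum()   # equals residual.sum()
recovered = accepted_mass + reject_prob * residual / residual.sum()

assert torch.allclose(recovered, q, atol=1e-5)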