diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index e17c029e70..d0681f18e2 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Optional
 
 import torch
 
@@ -97,7 +97,9 @@ class EarlyStopSampler(Sampler):
             request.py_result.append_context_logits(logits)
 
 
-def top_k_sampling_batch(logits, top_k=50, generator: torch.Generator = None):
+def top_k_sampling_batch(logits,
+                         top_k=50,
+                         generator: Optional[torch.Generator] = None):
     logits_dim = logits.dim()
     if logits_dim == 1:
         logits = logits.unsqueeze(0)
@@ -124,7 +126,7 @@ def top_k_sampling_batch(logits, top_k=50, generator: torch.Generator = None):
 def top_p_sampling_batch(logits: torch.Tensor,
                          top_p: float = 0.9,
                          temperature: float = 1.0,
-                         generator: torch.Generator = None):
+                         generator: Optional[torch.Generator] = None):
     logits_dim = logits.dim()
     if logits_dim == 1:
         logits = logits.unsqueeze(0)
@@ -136,6 +138,50 @@ def top_p_sampling_batch(logits: torch.Tensor,
 
     # sort the logits of each sample in descending order
     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+    # compute cumulative probability distribution of each sample
+    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
+                                    dim=-1)
+    # find where the cumulative probability exceeds top_p
+    sorted_indices_to_remove = cumulative_probs > top_p
+    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+    sorted_indices_to_remove[:, 0] = 0
+
+    # set logits outside the top_p nucleus to -inf
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        1, sorted_indices, sorted_indices_to_remove)
+    logits = logits.masked_fill(indices_to_remove, float('-inf'))
+
+    # compute probability distribution
+    softmax = torch.softmax(logits, dim=-1)
+
+    # sample from the distribution; result is squeezed to [batch_size]
+    next_tokens = torch.multinomial(softmax, num_samples=1,
+                                    generator=generator).squeeze(-1)
+    return next_tokens, softmax
+
+
+def top_k_top_p_sampling_batch(logits: torch.Tensor,
+                               top_k: int,
+                               top_p: float,
+                               temperature: float = 1.0,
+                               generator: Optional[torch.Generator] = None):
+    logits_dim = logits.dim()
+    if logits_dim == 1:
+        logits = logits.unsqueeze(0)
+    assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
+    if temperature != 0:
+        logits = logits / max(temperature, 1e-5)
+    batch_size, vocab_size = logits.size()
+    # get the top_k largest logits of each sample and their indices
+    values, indices = torch.topk(logits, top_k, dim=-1)
+    min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
+
+    # set logits below the k-th largest logit to -inf
+    logits = torch.where(logits < min_values,
+                         torch.full_like(logits, float('-inf')), logits)
+
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+
     # compute cumulative probability distribution of each sample
     cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
                                     dim=-1)
@@ -165,14 +211,50 @@ def greedy_search_sampling_batch(logits):
     return next_tokens, softmax
 
 
+def get_rejected_indices(draft_probs: torch.Tensor, target_probs: torch.Tensor,
+                         generator: torch.Generator, draft_tokens: list[int]):
+
+    p = draft_probs[torch.arange(len(draft_tokens)), draft_tokens]
+    q = target_probs[:-1]
+    q = q[torch.arange(len(draft_tokens)), draft_tokens]
+    accept_probs = torch.minimum(torch.ones(()), q / p)
+    # Use deterministic random generation for multi-GPU consistency
+    rejected_indices = (torch.rand(accept_probs.shape,
+                                   generator=generator,
+                                   device=accept_probs.device)
+                        > accept_probs).nonzero()
+    return rejected_indices
+
+
+def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor,
+                    generator: torch.Generator, num_accepted: int):
+
+    last_draft = draft_probs[num_accepted]
+    last_target = target_probs[num_accepted]
+    new = last_target - last_draft
+    new = torch.where(new > 0, new, 0.0)
+
+    new_token = torch.multinomial(new, num_samples=1,
+                                  generator=generator).squeeze(-1)
+    return new_token
+
+
 TopK = tuple[Literal["top_k"], int]
 TopP = tuple[Literal["top_p"], float, float]
+TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
 Greedy = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
 Strategy = TopK | TopP | Greedy
 
 
 def request_strategy(request: LlmRequest) -> Strategy:
+    if request.sampling_config.top_k is not None and len(
+            request.sampling_config.top_k
+    ) > 0 and request.sampling_config.top_p is not None and len(
+            request.sampling_config.top_p) > 0:
+        return ("top_k_top_p", request.sampling_config.top_k[0],
+                request.sampling_config.top_p[0],
+                request.sampling_config.temperature[0])
     if request.sampling_config.top_p is not None and len(
             request.sampling_config.top_p) > 0:
         return ("top_p", request.sampling_config.top_p[0],
@@ -190,12 +272,15 @@ def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]:
 
 def sample(strategy: Strategy,
            logits: torch.Tensor,
-           generator: torch.Generator = None):
+           generator: Optional[torch.Generator] = None):
     match strategy:
         case ("top_k", top_k):
             return top_k_sampling_batch(logits, top_k, generator)
         case ("top_p", top_p, temperature):
             return top_p_sampling_batch(logits, top_p, temperature, generator)
+        case ("top_k_top_p", top_k, top_p, temperature):
+            return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature,
+                                              generator)
         case ("greedy", None):
             return greedy_search_sampling_batch(logits)
 
 
@@ -331,84 +416,77 @@ class TorchSampler(Sampler):
             assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
             request.py_result.append_log_probs([token_log_probs])
 
+    def _process_draft_tokens_greedy(self, request: LlmRequest,
+                                     new_tokens: torch.Tensor) -> int:
+        new_token = add_token(request, new_tokens, beam=self.BEAM)
+        stop = self._handle_stop_criteria(request, new_token)
+        if stop or get_draft_token_length(request) == 0:
+            return 0
+        num_accepted = 0
+
+        for draft_token in request.py_draft_tokens:
+            if draft_token != new_token:
+                # Reject.
+                break
+
+            num_accepted += 1
+            new_token = add_token(request,
+                                  new_tokens,
+                                  beam=self.BEAM,
+                                  step=num_accepted)
+            if self._handle_stop_criteria(request, new_token):
+                break
+        return num_accepted
+
+    def _process_draft_tokens_rejection_sampling(
+            self, request: LlmRequest, new_tokens: torch.Tensor) -> int:
+        sampling_strategy = request_strategy(request)
+        generator = self.get_generator(request.py_draft_logits.device)
+        _, draft_probs = sample(sampling_strategy,
+                                request.py_draft_logits[0],
+                                generator=generator)
+        target_probs = request.py_target_probs
+        rejected_indices = get_rejected_indices(draft_probs, target_probs,
+                                                generator,
+                                                request.py_draft_tokens)
+        sample_last = True
+        stop = False
+        if rejected_indices.numel() == 0:
+            num_initially_accepted = get_draft_token_length(request)
+            sample_last = False
+        else:
+            num_initially_accepted = rejected_indices[0].item()
+        num_accepted = num_initially_accepted
+        for i in range(num_accepted):
+            new_token = request.py_draft_tokens[i]
+            new_tokens[i, request.seq_slot, self.BEAM] = new_token
+            request.add_new_token(new_token, self.BEAM)
+            stop = self._handle_stop_criteria(request, new_token)
+            if stop:
+                num_accepted = i + 1
+                return num_accepted
+        if sample_last:
+            new_token = sample_rejected(draft_probs, target_probs, generator,
+                                        num_accepted)
+            new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
+            request.add_new_token(new_token, self.BEAM)
+            stop = self._handle_stop_criteria(request, new_token)
+        else:
+            new_token = add_token(request,
+                                  new_tokens,
+                                  beam=self.BEAM,
+                                  step=num_accepted)
+            stop = self._handle_stop_criteria(request, new_token)
+
+        return num_accepted
+
     def process_draft_tokens(self, request: LlmRequest,
                              new_tokens: torch.Tensor) -> int:
         if request.py_draft_logits is None:
-            new_token = add_token(request, new_tokens, beam=self.BEAM)
-            stop = self._handle_stop_criteria(request, new_token)
-            if stop or get_draft_token_length(request) == 0:
-                return 0
-            num_accepted = 0
-
-            for draft_token in request.py_draft_tokens:
-                if draft_token != new_token:
-                    # Reject.
-                    break
-
-                num_accepted += 1
-                new_token = add_token(request,
-                                      new_tokens,
-                                      beam=self.BEAM,
-                                      step=num_accepted)
-                if self._handle_stop_criteria(request, new_token):
-                    break
-            return num_accepted
+            return self._process_draft_tokens_greedy(request, new_tokens)
         else:
-            sampling_strategy = request_strategy(request)
-            generator = self.get_generator(request.py_draft_logits.device)
-            _, draft_probs = sample(sampling_strategy,
-                                    request.py_draft_logits[0],
-                                    generator=generator)
-            target_probs = request.py_target_probs
-            p = draft_probs[torch.arange(get_draft_token_length(request)),
-                            request.py_draft_tokens]
-            q = target_probs[:-1]
-            q = q[torch.arange(get_draft_token_length(request)),
-                  request.py_draft_tokens]
-            accept_probs = torch.minimum(torch.ones(()), q / p)
-            # Use deterministic random generation for multi-GPU consistency
-            rejected_indices = (torch.rand(accept_probs.shape,
-                                           generator=generator,
-                                           device=accept_probs.device)
-                                > accept_probs).nonzero()
-            sample_last = True
-            stop = False
-            if rejected_indices.numel() == 0:
-                num_initially_accepted = get_draft_token_length(request)
-                sample_last = False
-            else:
-                num_initially_accepted = rejected_indices[0].item()
-            num_accepted = num_initially_accepted
-            for i in range(num_accepted):
-                new_token = request.py_draft_tokens[i]
-                new_tokens[i, request.seq_slot, self.BEAM] = new_token
-                request.add_new_token(new_token, self.BEAM)
-                stop = self._handle_stop_criteria(request, new_token)
-                if stop:
-                    num_accepted = i + 1
-                    break
-            if not stop and sample_last:
-                last_draft = draft_probs[num_accepted]
-                last_target = target_probs[num_accepted]
-                new = last_target - last_draft
-                new = torch.where(new > 0, new, 0.0)
-
-                new_token = torch.multinomial(new,
-                                              num_samples=1,
-                                              generator=generator).squeeze(-1)
-
-                new_tokens[num_accepted, request.seq_slot,
-                           self.BEAM] = new_token
-                request.add_new_token(new_token, self.BEAM)
-                stop = self._handle_stop_criteria(request, new_token)
-            elif not stop and not sample_last:
-                new_token = add_token(request,
-                                      new_tokens,
-                                      beam=self.BEAM,
-                                      step=num_accepted)
-                stop = self._handle_stop_criteria(request, new_token)
-
-            return num_accepted
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens)
 
     def update_requests(self, state: SampleState) -> None:
         assert isinstance(state, SampleState)
@@ -601,7 +679,6 @@ class TorchSampler(Sampler):
                 current_slice = slice(0, steps), slot, beam
                 new_tokens[current_slice] = next_tokens
                 if request.py_draft_logits is not None:
-                    # Could be cleaner
                     request.py_target_probs = softmax.clone()
                 if gen_logits_host is not None:
                     gen_logits_host[current_slice].copy_(logits, non_blocking=True)
diff --git a/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py b/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py
new file mode 100644
index 0000000000..13960e528b
--- /dev/null
+++ b/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py
@@ -0,0 +1,53 @@
+import unittest
+
+import numpy as np
+import torch
+from scipy.stats import entropy
+
+from tensorrt_llm._torch.pyexecutor.sampler import (get_rejected_indices,
+                                                    sample_rejected)
+
+
+def test_get_rejected_indices():
+    vocab_size = 500
+    num_iter = 50000
+    draft_probs = torch.rand(1, vocab_size)
+    drop_idx = torch.topk(draft_probs[0], k=400, largest=False)[1]
+    draft_probs[0, drop_idx] = 0.0
+    draft_probs = draft_probs / draft_probs.sum(dim=-1, keepdim=True)
+    target_probs = torch.rand(2, vocab_size)
+    drop_idx = torch.topk(target_probs[0], k=400, largest=False)[1]
+    target_probs[0, drop_idx] = 0.0
+    target_probs = target_probs / target_probs.sum(dim=-1, keepdim=True)
+    generator = torch.Generator()
+    sampled_tokens = []
+    sampled_regular = []
+    for _ in range(num_iter):
+        draft_tokens = [
+            torch.multinomial(draft_probs, num_samples=1,
+                              generator=generator).item()
+        ]
+        rejected_indices = get_rejected_indices(draft_probs, target_probs,
+                                                generator, draft_tokens)
+        if rejected_indices.shape[0] == 0:
+            sampled_tokens.append(draft_tokens[0])
+        else:
+            sampled_tokens.append(
+                sample_rejected(draft_probs, target_probs, generator, 0).item())
+        sampled_regular.append(
+            torch.multinomial(target_probs[0],
+                              num_samples=1,
+                              generator=generator).item())
+    bins = np.arange(vocab_size + 1) - 0.5  # Bins for histogram
+    sampled_tokens, _ = np.histogram(sampled_tokens, bins=bins, density=True)
+    sampled_regular, _ = np.histogram(sampled_regular, bins=bins, density=True)
+    expected_prob = target_probs[0].squeeze().numpy()
+
+    # KL Divergence check
+    kl_divergence = entropy(expected_prob, sampled_tokens)
+    kl_divergence_regular = entropy(expected_prob, sampled_regular)
+    assert abs(kl_divergence - kl_divergence_regular) < 0.01
+
+
+if __name__ == "__main__":
+    unittest.main()
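
For readers of this diff, below is a minimal, standalone sketch (not part of the change) of how the two helpers it adds, get_rejected_indices and sample_rejected, compose for one speculative-decoding step. The import mirrors the new unit test; the toy vocabulary size, draft length, and the way the distributions are built are illustrative assumptions, and the bonus-token sampling in the accept-all branch stands in for what TorchSampler does via add_token.

# Illustrative only: toy shapes and distributions, not the TorchSampler code path.
import torch

from tensorrt_llm._torch.pyexecutor.sampler import (get_rejected_indices,
                                                     sample_rejected)

vocab_size = 8  # assumed toy vocabulary
num_draft = 3   # assumed number of draft tokens
generator = torch.Generator().manual_seed(0)

# draft_probs: [num_draft, vocab], target_probs: [num_draft + 1, vocab],
# matching the shapes the sampler passes to the helpers.
draft_probs = torch.softmax(
    torch.randn(num_draft, vocab_size, generator=generator), dim=-1)
target_probs = torch.softmax(
    torch.randn(num_draft + 1, vocab_size, generator=generator), dim=-1)

# Tokens proposed by the draft model, one per draft position.
draft_tokens = torch.multinomial(
    draft_probs, num_samples=1, generator=generator).squeeze(-1).tolist()

# Positions whose draft token is rejected under the min(1, q/p) acceptance rule.
rejected = get_rejected_indices(draft_probs, target_probs, generator,
                                draft_tokens)

if rejected.numel() == 0:
    # All draft tokens accepted; a bonus token comes from the target model's
    # last-position distribution (the sampler does this via add_token).
    bonus = torch.multinomial(target_probs[-1], num_samples=1,
                              generator=generator).item()
    print("accepted:", draft_tokens, "bonus token:", bonus)
else:
    num_accepted = rejected[0].item()
    # Resample the first rejected position from max(target - draft, 0).
    corrected = sample_rejected(draft_probs, target_probs, generator,
                                num_accepted).item()
    print("accepted:", draft_tokens[:num_accepted], "corrected token:", corrected)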