Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[TRTLLM-8414][chore] BREAKING CHANGE: refine sampling strategy selection (#8132)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
parent e98616512f
commit 8298e93bd8
@@ -34,6 +34,7 @@ extend_skip_glob = [
    "tests/unittest/_torch/modeling/test_modeling_mistral.py",
    "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
    "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
    "tests/unittest/_torch/sampler/test_torch_sampler.py",
]

[tool.yapf]
@@ -65,6 +66,7 @@ ignore_patterns = [
    "tests/unittest/_torch/modeling/test_modeling_mistral.py",
    "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
    "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
    "tests/unittest/_torch/sampler/test_torch_sampler.py",
]

[tool.codespell]
@@ -144,6 +146,7 @@ include = [
    "tests/unittest/_torch/modeling/test_modeling_mistral.py",
    "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
    "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
    "tests/unittest/_torch/sampler/test_torch_sampler.py",
]
exclude = [
    "**3rdparty/**",
@@ -234,8 +234,10 @@ class DemoEngine(ADEngine):
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        logits_shape = logits.shape
        logits = logits.view(-1, logits_shape[-1])  # sampling_batch expects 2D logits
        if isinstance(sampling_params.top_k, int):
            idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k)
        if isinstance(sampling_params.top_k, int) and sampling_params.top_k > 1:
            idx_next, probs = top_k_sampling_batch(
                logits, top_k=sampling_params.top_k, temperature=1.0
            )
        else:
            idx_next, probs = greedy_search_sampling_batch(logits)
        idx_next = idx_next.view(logits_shape[:-1])
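Why the added top_k > 1 guard routes such requests to greedy search: with a single surviving candidate, multinomial sampling over the renormalized top-1 distribution always returns the argmax, so greedy is a pure optimization. A minimal standalone sketch (illustrative code, not repo code):

import torch

logits = torch.tensor([[0.1, 2.5, -1.0, 0.3]])
values, indices = torch.topk(logits, k=1, dim=-1)   # only the best token survives
masked = torch.full_like(logits, float("-inf")).scatter(-1, indices, values)
probs = torch.softmax(masked, dim=-1)                # effectively a one-hot distribution
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
assert torch.equal(sampled, logits.argmax(dim=-1))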
@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import repeat
from typing import Any, List, Literal, Optional, cast
from typing import Any, List, Literal, Optional, TypeVar, cast

import torch
import torch.nn.functional as F
@@ -26,6 +26,7 @@ from tensorrt_llm.bindings.internal.runtime import (BufferManager, CudaEvent,
                                                     GptDecoderBatched)
from tensorrt_llm.executor.result import Logprob
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.sampling_params import SamplingParams

from ..speculative.spec_tree_manager import SpecTreeManager
from .finish_reason import FinishedState
@@ -195,84 +196,75 @@ class EarlyStopWithMMResult(Sampler):

def top_k_sampling_batch(
    logits,
    top_k=50,
    generator: Optional[torch.Generator] = None
    *,
    top_k: int,
    temperature: float,
    generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    logits_dim = logits.dim()
    if logits_dim == 1:
        logits = logits.unsqueeze(0)
    # logits should be 2D :[batch_size, vocab_size]
    batch_size, vocab_size = logits.size()

    # get first top_k logits of each sample and their indices
    if top_k > 0:
        values, indices = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)

        # set the logits who is less than first top_k logits to -inf
        logits = torch.where(logits < min_values,
                             torch.full_like(logits, float('-inf')), logits)

    # compute probability distribution
    softmax = torch.softmax(logits, dim=-1)

    # sample from the distribution and generate result of [batch_size, 1]
    next_tokens = torch.multinomial(softmax, num_samples=1,
                                    generator=generator).squeeze(-1)
    return next_tokens, softmax
    # NB: To be replaced by a more efficient implementation.
    return top_k_top_p_sampling_batch(
        logits,
        top_k=top_k,
        temperature=temperature,
        generator=generator,
        top_p=1,
    )


def top_p_sampling_batch(
    logits: torch.Tensor,
    *,
    top_p: float = 0.9,
    temperature: float = 1.0,
    top_p: float,
    temperature: float,
    generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    # NB: To be replaced by a more efficient implementation.
    return top_k_top_p_sampling_batch(
        logits,
        top_p=top_p,
        top_k=logits.size(1),
        temperature=temperature,
        generator=generator,
    )


def temperature_sampling_batch(
    logits: torch.Tensor,
    *,
    temperature: float,
    generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    # NB: To be replaced by a more efficient implementation.
    return top_k_top_p_sampling_batch(
        logits,
        top_p=1,
        top_k=logits.size(1),
        temperature=temperature,
        generator=generator,
    )


def top_k_top_p_sampling_batch(
    logits: torch.Tensor,
    *,
    top_k: int,
    top_p: float,
    temperature: float,
    generator: Optional[torch.Generator] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    logits_dim = logits.dim()
    assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"

    if temperature != 0:
        logits = logits / max(temperature, 1e-5)

    # sort the logits of each sample in descending order
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)

    # compute cumulative probability distribution of each sample
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
                                    dim=-1)
    # get the location of top_p
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    # set the logits to -inf whose is outside top_p
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, float('-inf'))

    # compute probability distribution
    softmax = torch.softmax(logits, dim=-1)

    # sample from the distribution and generate result of [batch_size, 1]
    next_tokens = torch.multinomial(softmax, num_samples=1,
                                    generator=generator).squeeze(-1)
    return next_tokens, softmax


def top_k_top_p_sampling_batch(logits: torch.Tensor,
                               *,
                               top_k: int,
                               top_p: float,
                               temperature: float = 1.0,
                               generator: Optional[torch.Generator] = None):
    logits_dim = logits.dim()
    assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
    if temperature != 0:
        logits = logits / max(temperature, 1e-5)
    assert temperature > 0, "non-greedy sampling requires valid temperature"
    logits = logits / max(temperature, 1e-5)
    batch_size, vocab_size = logits.size()
    # get first top_k logits of each sample and their indices
    if top_k > 0:

    assert top_k > 1, "non-greedy sampling requires valid top_k"
    need_top_k = top_k < vocab_size
    assert top_p > 0, "non-greedy sampling requires valid top_p"
    need_top_p = top_p < 1

    # top-K: mask out logits not belonging to the top-K for each sample
    if need_top_k:
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
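For orientation, the delegating helpers above all funnel into top_k_top_p_sampling_batch, which applies temperature scaling, a top-k cut-off, a top-p (nucleus) cut-off, and then multinomial sampling. A self-contained sketch of that combined procedure, assuming 2D [batch_size, vocab_size] logits (illustrative code, not the repo implementation):

import torch

def sample_top_k_top_p(logits: torch.Tensor, *, top_k: int, top_p: float,
                       temperature: float = 1.0) -> torch.Tensor:
    logits = logits / max(temperature, 1e-5)
    # top-k: drop everything below the k-th largest logit of each row
    if top_k < logits.size(-1):
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    # top-p: drop the tail of the sorted distribution beyond cumulative mass top_p
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the token that crosses the threshold
        remove[..., 0] = False                      # never remove the most likely token
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), float("-inf"))
    # sample one token per row from the renormalized distribution
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)

tokens = sample_top_k_top_p(torch.randn(2, 32), top_k=8, top_p=0.9, temperature=0.7)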
@@ -280,21 +272,28 @@ def top_k_top_p_sampling_batch(logits: torch.Tensor,
        logits = torch.where(logits < min_values,
                             torch.full_like(logits, float('-inf')), logits)

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    # top-p: mask out logits outside the nucleus
    if need_top_p:
        sorted_logits, sorted_indices = torch.sort(logits,
                                                   descending=True,
                                                   dim=-1)

    # compute cumulative probability distribution of each sample
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
                                    dim=-1)
        # compute cumulative probability distribution of each sample
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
                                        dim=-1)

    # get the location of top_p
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0
        # get the location of top_p
        # NB: Currently selecting the smallest index with cumulative_probs > top_p.
        #     Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[:,
                                 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = 0

    # set the logits to -inf whose is outside top_p
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, float('-inf'))
        # set the logits to -inf for token indices outside top_p
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, float('-inf'))

    # compute probability distribution
    softmax = torch.softmax(logits, dim=-1)
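The shifted mask implements the convention noted in the comment above: the token whose cumulative probability first exceeds top_p is still kept, so top_p -> 0 approaches greedy selection. A small standalone worked example with assumed values:

import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])   # already sorted, descending
cum = probs.cumsum(dim=-1)                       # [0.50, 0.80, 0.95, 1.00]
remove = cum > 0.9                               # [False, False, True, True]
remove[:, 1:] = remove[:, :-1].clone()           # shift right: keep the crossing token
remove[:, 0] = False                             # the best token is never removed
print(remove)                                    # tensor([[False, False, False,  True]])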
@@ -359,48 +358,100 @@ def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor,
    return new_token


TopK = tuple[Literal["top_k"], int]
TemperatureOnly = tuple[Literal["temperature"], float]
TopK = tuple[Literal["top_k"], int, float]
TopP = tuple[Literal["top_p"], float, float]
TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
Greedy = tuple[Literal["greedy"], None]
GREEDY: Greedy = ("greedy", None)
Strategy = TopK | TopP | Greedy | TopKTopP
Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly

T = TypeVar('T')


def _request_strategy(request: LlmRequest) -> Strategy:
    # top_p and top_K with temperature=0.0 reduces to greedy
    # sampling
    temperature = request.sampling_config.temperature
    if temperature is not None:
        temperature = temperature[0]
        if temperature == 0.0:
            return GREEDY
# Due to tensorrt_llm::runtime::SamplingConfig using vectors, params
# in LlmRequest.sampling_params are either None or single-element lists.
# This helper method simplifies code using such params.
def _unwrap_singleton(p: Optional[List[T]]) -> Optional[T]:
    if p is None:
        return None
    t, = p
    return t

    if request.sampling_config.top_k is not None and len(
            request.sampling_config.top_k
    ) > 0 and request.sampling_config.top_p is not None and len(
            request.sampling_config.top_p) > 0:
        return ("top_k_top_p", request.sampling_config.top_k[0],
                request.sampling_config.top_p[0], temperature)
    elif request.sampling_config.top_p is not None and len(
            request.sampling_config.top_p) > 0:
        top_p = request.sampling_config.top_p[0]
        return ("top_p", top_p, temperature)
    elif request.sampling_config.top_k is not None and len(
            request.sampling_config.top_k) > 0:
        return ("top_k", request.sampling_config.top_k[0])
    else:


@dataclass(frozen=True, kw_only=True)
class TorchSamplerSamplingParams:
    """Subset of tensorrt_llm::runtime::SamplingConfig handled by TorchSampler."""
    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]


def _request_get_sampling_params(
        request: LlmRequest) -> TorchSamplerSamplingParams:
    sampling_config = request.sampling_config
    temperature = _unwrap_singleton(
        cast(Optional[List[float]], sampling_config.temperature))
    top_p = _unwrap_singleton(cast(Optional[List[float]],
                                   sampling_config.top_p))
    top_k = _unwrap_singleton(cast(Optional[List[int]], sampling_config.top_k))

    return TorchSamplerSamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )


def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
    # The semantics are specified in the doc-string of SamplingParams

    params = _request_get_sampling_params(request)
    temperature = params.temperature
    top_p = params.top_p
    top_k = params.top_k

    if SamplingParams.params_imply_greedy_decoding(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
    ):
        return GREEDY

    # --- resolving default values
    # NB: not greedy, hence temperature != 0 if specified
    temperature = temperature or 1.0

    # NB: not greedy, hence top_p != 0 if specified
    top_p = top_p or 1.0
    # NB: not greedy, hence top_k != 1 if specified
    #     (0 and vocab_size are equivalent)
    top_k = top_k or vocab_size

    assert top_k > 1, "non-greedy sampling requires valid top_k"
    need_top_k = top_k < vocab_size
    assert top_p > 0, "non-greedy sampling requires valid top_p"
    need_top_p = top_p < 1

    if need_top_p:
        if need_top_k:
            return ("top_k_top_p", top_k, top_p, temperature)
        return ("top_p", top_p, temperature)
    if need_top_k:
        return ("top_k", top_k, temperature)
    return ("temperature", temperature)
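The selection rules of the new _request_strategy can be summarized in a minimal standalone sketch operating on plain optionals instead of an LlmRequest (the helper name resolve_strategy is illustrative, not part of the codebase):

from typing import Optional

def resolve_strategy(temperature: Optional[float], top_p: Optional[float],
                     top_k: Optional[int], vocab_size: int):
    # greedy if nothing is specified, or if any parameter explicitly forces it
    if ((temperature is None and top_p is None and top_k is None)
            or temperature == 0 or top_p == 0 or top_k == 1):
        return ("greedy", None)
    temperature = temperature or 1.0   # unspecified -> neutral default
    top_p = top_p or 1.0
    top_k = top_k or vocab_size        # 0 and vocab_size both mean "all logits"
    need_top_k, need_top_p = top_k < vocab_size, top_p < 1
    if need_top_p and need_top_k:
        return ("top_k_top_p", top_k, top_p, temperature)
    if need_top_p:
        return ("top_p", top_p, temperature)
    if need_top_k:
        return ("top_k", top_k, temperature)
    return ("temperature", temperature)

assert resolve_strategy(None, None, None, 1000) == ("greedy", None)
assert resolve_strategy(0.7, None, 42, 1000) == ("top_k", 42, 0.7)
assert resolve_strategy(None, 0.9, None, 1000) == ("top_p", 0.9, 1.0)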
def _group_requests_by_sampling_strategy(
        requests: Iterable[LlmRequest],
        *,
        pin_memory: bool = False) -> dict[Strategy, torch.Tensor]:
        pin_memory: bool = False,
        vocab_size: int) -> dict[Strategy, torch.Tensor]:
    # NB: Client code relies on request indices in returned torch.Tensor being sorted.
    strategy_dict: dict[Strategy, list[int]] = defaultdict(list)
    for req_index, req in enumerate(requests):
        strategy_dict[_request_strategy(req)].append(req_index)
        strategy_dict[_request_strategy(
            req, vocab_size=vocab_size)].append(req_index)
    return {
        strategy: torch.tensor(indices,
                               pin_memory=pin_memory,
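_group_requests_by_sampling_strategy batches together requests that resolve to the same strategy tuple, keeping per-strategy request indices in ascending order. A standalone sketch of the grouping idea with hypothetical inputs:

from collections import defaultdict
import torch

strategies = [("greedy", None), ("top_k", 40, 1.0), ("greedy", None)]
groups: dict[tuple, list[int]] = defaultdict(list)
for i, s in enumerate(strategies):
    groups[s].append(i)                       # enumeration order keeps indices sorted
index_tensors = {s: torch.tensor(idx, dtype=torch.int32) for s, idx in groups.items()}
print(index_tensors[("greedy", None)])        # tensor([0, 2], dtype=torch.int32)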
@@ -418,23 +469,32 @@ def sample(
) -> tuple[torch.Tensor, torch.Tensor]:
    filter_softmax = True
    match strategy:
        case ("top_k", top_k):
            tokens, softmax = top_k_sampling_batch(logits, top_k, generator)
        case ("top_k", top_k, temperature):
            tokens, softmax = top_k_sampling_batch(logits,
                                                   top_k=top_k,
                                                   temperature=temperature,
                                                   generator=generator)
        case ("top_p", top_p, temperature):
            tokens, softmax = top_p_sampling_batch(
                logits,
                top_p=top_p,
                generator=generator,
                **(dict(temperature=temperature)
                   if temperature is not None else dict()))
                temperature=temperature,
            )
        case ("top_k_top_p", top_k, top_p, temperature):
            tokens, softmax = top_k_top_p_sampling_batch(
                logits,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                generator=generator,
                **(dict(temperature=temperature)
                   if temperature is not None else dict()))
            )
        case ("temperature", temperature):
            tokens, softmax = temperature_sampling_batch(
                logits,
                temperature=temperature,
                generator=generator,
            )
        case ("greedy", None):
            tokens, softmax = greedy_search_sampling_batch(
                logits, softmax_indices=softmax_indices)
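The sample() dispatch relies on the tuple tags defined earlier; a minimal standalone illustration of matching those tuple shapes (placeholder handlers, not repo code):

def dispatch(strategy) -> str:
    match strategy:
        case ("top_k", top_k, temperature):
            return f"top-k sampling, k={top_k}, T={temperature}"
        case ("top_p", top_p, temperature):
            return f"top-p sampling, p={top_p}, T={temperature}"
        case ("top_k_top_p", top_k, top_p, temperature):
            return f"combined, k={top_k}, p={top_p}, T={temperature}"
        case ("temperature", temperature):
            return f"temperature-only, T={temperature}"
        case ("greedy", None):
            return "greedy"

print(dispatch(("top_k", 42, 0.9)))   # top-k sampling, k=42, T=0.9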
@@ -1070,7 +1130,11 @@ class TorchSampler(Sampler):
    def _process_draft_tokens_rejection_sampling(
            self, request: LlmRequest, new_tokens: list[list[list[int]]],
            new_tokens_tensor: torch.Tensor) -> int:
        sampling_strategy = _request_strategy(request)
        # FIXME: Passing a dummy vocab_size could result in unnecessary
        #        filtering of vocab_size logits, out of vocab_size in
        #        total. The 'sample' below should generally be avoided
        #        by retaining the draft_probs during drafting (TRTLLM-7772).
        sampling_strategy = _request_strategy(request, vocab_size=2**31)
        generator = self.get_generator(request.py_draft_logits.device)
        _, draft_probs = sample(sampling_strategy,
                                request.py_draft_logits,
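For context on why draft probabilities are needed here: standard speculative-decoding rejection sampling accepts a draft token t with probability min(1, p_target(t) / p_draft(t)). A background sketch of that acceptance rule (illustrative code, not the repo's sample_rejected implementation):

from typing import Optional
import torch

def accept_draft(draft_probs: torch.Tensor, target_probs: torch.Tensor,
                 draft_token: int, generator: Optional[torch.Generator] = None) -> bool:
    # both inputs are 1-D probability vectors over the vocabulary
    p_draft = draft_probs[draft_token].clamp_min(1e-20)
    p_target = target_probs[draft_token]
    u = torch.rand((), generator=generator)
    return bool(u < torch.minimum(torch.ones(()), p_target / p_draft))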
@@ -1329,7 +1393,7 @@ class TorchSampler(Sampler):
                                       dim=-1)

        requests_by_strategy = _group_requests_by_sampling_strategy(
            requests, pin_memory=True)
            requests, pin_memory=True, vocab_size=logits_cuda.size(1))
        generator_cuda = self.get_generator(cuda_device)

        # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
@@ -1706,8 +1770,17 @@ class TorchSampler(Sampler):

    @override
    def should_provide_draft_probs(self, request: LlmRequest) -> bool:
        params = _request_get_sampling_params(request)
        temperature = params.temperature
        top_p = params.top_p
        top_k = params.top_k

        # Do not request draft probs when sampling is greedy.
        return _request_strategy(request) is not GREEDY
        return not SamplingParams.params_imply_greedy_decoding(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )


class Algorithms:
@@ -154,13 +154,25 @@ class SamplingParams:
        best_of (int, optional): Number of sequences to consider for best output. Defaults to None.
        use_beam_search (bool): Whether to use beam search. Defaults to False.

        top_k (int, optional): Controls number of logits to sample from. None means using C++ runtime default 0, i.e., all logits. Defaults to None.
        top_p (float, optional): Controls the top-P probability to sample from. None means using C++ runtime default 0.f. Defaults to None.
        top_k (int, optional): Controls number of logits to sample from. Can assume non-negative values, where 0 means 'all logits'. Defaults to None.
            The value None is treated as "not specified" in the following.
            If neither temperature, top_p, nor top_k are specified, sampling is greedy.
            If temperature > 0 and/or top_p < 1 are specified, sampling will proceed accordingly and top_k will default to top_k = 0.
            Setting top_k = 1 results in greedy sampling.
        top_p (float, optional): Controls the top-P probability to sample from. Can have values between 0 and 1. Defaults to None.
            The value None is treated as "not specified" in the following.
            If neither temperature, top_p, nor top_k are specified, sampling is greedy.
            If temperature > 0 and/or top_k > 1 are specified, sampling will proceed accordingly and top_p will default to top_p = 1.
            Setting top_p = 0 should result in greedy sampling, but is currently disallowed in the backend.
        top_p_min (float, optional): Controls decay in the top-P algorithm. topPMin is lower-bound. None means using C++ runtime default 1.e-6. Defaults to None.
        top_p_reset_ids (int, optional): Controls decay in the top-P algorithm. Indicates where to reset the decay. None means using C++ runtime default 1. Defaults to None.
        top_p_decay (float, optional): Controls decay in the top-P algorithm. The decay value. None means using C++ runtime default 1.f. Defaults to None.
        seed (int, optional): Controls the random seed used by the random number generator in sampling. None means using C++ runtime default 0. Defaults to None.
        temperature (float, optional): Controls the modulation of logits when sampling new tokens. It can have values > 0.f. None means using C++ runtime default 1.0f. Defaults to None.
        temperature (float, optional): Controls the modulation of logits when sampling new tokens. It can have values >= 0.f. Defaults to None.
            The value None is treated as "not specified" in the following.
            If neither temperature, top_p, nor top_k are specified, sampling is greedy.
            If top_p < 1 and/or top_k > 1 are specified, sampling will proceed accordingly and temperature will default to temperature = 1.
            Setting temperature = 0 results in greedy sampling.
        min_tokens (int, optional): Lower bound on the number of tokens to generate. Values < 1 have no effect. None means using C++ runtime default 1. Defaults to None.
        beam_search_diversity_rate (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None.
        repetition_penalty (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None.
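A usage sketch of the semantics documented above, assuming the tensorrt_llm package is importable (values are illustrative):

from tensorrt_llm.sampling_params import SamplingParams

SamplingParams()                       # nothing specified -> greedy decoding
SamplingParams(temperature=0)          # explicit greedy, regardless of top_k / top_p
SamplingParams(top_k=1)                # greedy as well
SamplingParams(temperature=0.8)        # temperature-only sampling; top_k / top_p stay unrestricted
SamplingParams(top_k=40, top_p=0.95)   # top-k + top-p, temperature defaults to 1
try:
    SamplingParams(top_p=2)            # rejected by validation
except ValueError as e:
    print(e)                           # require 0 <= top_p <= 1, got top_p=2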
@@ -296,11 +308,19 @@ class SamplingParams:
        For instance, while the greedy decoding with n > 1 is capable in the
        Executor class of C++ runtime, the LLM API disallows such combination.
        """
        if self.best_of < self.n:
        if self.top_p is not None and (self.top_p < 0 or self.top_p > 1):
            raise ValueError(f"require 0 <= top_p <= 1, got top_p={self.top_p}")
        if self.top_k is not None and self.top_k < 0:
            raise ValueError(f"require top_k >= 0, got top_k={self.top_k}")
        if self.temperature is not None and self.temperature < 0:
            raise ValueError(f"require temperature >= 0, got temperature={self.temperature}")

        if self.best_of is not None and self.best_of < self.n:
            raise ValueError(f"best_of ({self.best_of}) cannot be less than n ({self.n})")

        if (
            self.best_of > 1
            self.best_of is not None
            and self.best_of > 1
            and self._greedy_decoding
            and not os.environ.get("TLLM_ALLOW_N_GREEDY_DECODING", None)
        ):
@@ -324,12 +344,25 @@ class SamplingParams:
        self.logprobs = self.logprobs and int(self.logprobs)
        self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs)

    # NB: Static, because downstream code only holds instances of
    #     bindings.SamplingConfig (not SamplingParams).
    @staticmethod
    def params_imply_greedy_decoding(
        *, temperature: Optional[float], top_p: Optional[float], top_k: Optional[int]
    ):
        return (
            (temperature is None and top_p is None and top_k is None)
            or top_k == 1
            or top_p == 0.0
            or temperature == 0
        )

    @property
    def _greedy_decoding(self) -> bool:
        return (
            not self.use_beam_search
            and (self.top_k is None or self.top_k == 1)
            and (self.top_p is None or self.top_p == 0.0)
        return not self.use_beam_search and self.params_imply_greedy_decoding(
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
        )

    @property
@@ -14,6 +14,7 @@ l0_a10:
  backend: pytorch
  tests:
  # ------------- PyTorch tests ---------------
  - unittest/_torch/sampler/test_torch_sampler.py
  - unittest/_torch/modeling/test_modeling_mistral.py
  - unittest/_torch/modeling/test_modeling_pixtral.py
  # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
tests/unittest/_torch/sampler/test_torch_sampler.py (new file, 303 lines)
@@ -0,0 +1,303 @@
from itertools import product
from typing import Optional, cast

import pytest
from utils.util import force_ampere

from tensorrt_llm._torch.pyexecutor.sampler import (
    GREEDY,
    LlmRequest,
    TorchSampler,
    _request_strategy,
)
from tensorrt_llm.bindings import SamplingConfig
from tensorrt_llm.sampling_params import SamplingParams


@force_ampere
class TestStrategySelection:
    VOCAB_SIZE = 1000
    TOP_K_VALS = [None, 0, 1, 42, 1000]
    TOP_P_VALS = [None, 0, 0.42, 1]
    TEMPERATURE_VALS = [None, 0, 1.42]

    # For non-greedy sampling, the following choices have no effect.
    TOP_P_NEUTRAL_VALS = [None, 1]
    TOP_K_NEUTRAL_VALS = [None, 0, VOCAB_SIZE]
    TEMPERATURE_NEUTRAL_VALS = [None, 1]

    TEMPERATURE_NOT_GREEDY = [0.42] + [t for t in TEMPERATURE_NEUTRAL_VALS if t is not None]

    class MockLlmRequest:
        sampling_config: SamplingConfig

    def _check_params(self, params: SamplingParams):
        # cf. description of 'top_p' in doc-string of SamplingParams and
        # test_top_p_0_disallowed below.
        if params.top_p == 0:
            pytest.skip("top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig")

    def test_top_p_0_disallowed(self):
        # If this xpasses, update _check_params and doc-string of SamplingParams.
        params = SamplingParams(top_p=0)
        pytest.xfail("top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig")
        params._get_sampling_config()

    def _build_mock_llm_request(self, params: SamplingParams) -> LlmRequest:
        request = self.MockLlmRequest()
        request.sampling_config = SamplingConfig(params._get_sampling_config())
        return cast(LlmRequest, request)

    def test_defaults(self):
        # NB: The code in _request_strategy relies on the default values below.
        default_params = SamplingParams()
        assert default_params.top_k is None
        assert default_params.top_p is None
        assert default_params.temperature is None

    def test_defaults_config(self):
        # NB: The code in _request_strategy relies on the default values below.
        default_config = SamplingParams()._get_sampling_config()
        assert default_config.top_k is None
        assert default_config.top_p is None
        assert default_config.temperature is None

    def test_defaults_request(self):
        # NB: The code in _request_strategy relies on the default values below.
        request = self._build_mock_llm_request(SamplingParams())
        default_config = request.sampling_config
        assert default_config.top_k is None
        assert default_config.top_p is None
        assert default_config.temperature is None

    def test_default_is_greedy(self):
        request = self._build_mock_llm_request(SamplingParams())
        assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY

    @pytest.mark.parametrize(
        "top_p, top_k",
        [
            pytest.param(top_p, top_k)
            # https://stackoverflow.com/a/75421799, does not work with nested loops
            for (top_k, top_p) in product(TOP_K_VALS, TOP_P_VALS)
        ],
    )
    def test_temperature_0_is_greedy(self, top_p: Optional[float], top_k: Optional[int]):
        params = SamplingParams(temperature=0, top_p=top_p, top_k=top_k)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY

    @pytest.mark.parametrize(
        "temperature, top_k",
        [
            pytest.param(temperature, top_k)
            # https://stackoverflow.com/a/75421799, does not work with nested loops
            for (temperature, top_k) in product(TEMPERATURE_VALS, TOP_K_VALS)
        ],
    )
    def test_top_p_0_is_greedy(self, temperature: Optional[float], top_k: Optional[int]):
        params = SamplingParams(top_p=0, temperature=temperature, top_k=top_k)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY

    @pytest.mark.parametrize(
        "temperature, top_p",
        [
            pytest.param(temperature, top_p)
            # https://stackoverflow.com/a/75421799, does not work with nested loops
            for (temperature, top_p) in product(TEMPERATURE_VALS, TOP_P_VALS)
        ],
    )
    def test_top_k_1_is_greedy(self, temperature: Optional[float], top_p: Optional[float]):
        params = SamplingParams(top_p=top_p, temperature=temperature, top_k=1)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY

    @pytest.mark.parametrize(
        "temperature, trivial_top_p, trivial_top_k",
        [
            pytest.param(temperature, top_p, top_k)
            # https://stackoverflow.com/a/75421799, does not work with nested loops
            for (temperature, top_k, top_p) in product(
                TEMPERATURE_NOT_GREEDY, TOP_K_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS
            )
        ],
    )
    def test_temperature_only(
        self, temperature: float, trivial_top_p: Optional[float], trivial_top_k: Optional[int]
    ):
        params = SamplingParams(temperature=temperature, top_p=trivial_top_p, top_k=trivial_top_k)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        assert len(strat) == 2
        assert strat[0] == "temperature"
        assert strat[1] == pytest.approx(temperature)

    @pytest.mark.parametrize(
        "trivial_temperature, trivial_top_k",
        [
            pytest.param(temperature, top_k)
            # https://stackoverflow.com/a/75421799, does not work with nested loops
            for (temperature, top_k) in product(TEMPERATURE_NEUTRAL_VALS, TOP_K_NEUTRAL_VALS)
        ],
    )
    def test_top_p_only(self, trivial_temperature: Optional[float], trivial_top_k: Optional[int]):
        params = SamplingParams(top_p=0.42, temperature=trivial_temperature, top_k=trivial_top_k)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        assert len(strat) == 3
        assert strat[0] == "top_p"
        assert strat[1] == pytest.approx(0.42)
        assert strat[2] == pytest.approx(1.0)

    @pytest.mark.parametrize(
        "trivial_top_k",
        [
            pytest.param(top_k)
            for top_k in TOP_K_NEUTRAL_VALS  # https://stackoverflow.com/a/75421799
        ],
    )
    def test_top_p_with_temperature(self, trivial_top_k: Optional[int]):
        params = SamplingParams(top_p=0.42, temperature=0.9, top_k=trivial_top_k)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        assert len(strat) == 3
        assert strat[0] == "top_p"
        assert strat[1] == pytest.approx(0.42)
        assert strat[2] == pytest.approx(0.9)

    @pytest.mark.parametrize(
        "trivial_temperature, trivial_top_p",
        [
            pytest.param(temperature, top_p)
            # https://stackoverflow.com/a/75421799, does not work with nested loops
            for (temperature, top_p) in product(TEMPERATURE_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS)
        ],
    )
    def test_top_k_only(self, trivial_temperature: Optional[float], trivial_top_p: Optional[float]):
        params = SamplingParams(top_k=42, temperature=trivial_temperature, top_p=trivial_top_p)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        assert len(strat) == 3
        assert strat[0] == "top_k"
        assert strat[1] == 42
        assert strat[2] == pytest.approx(1.0)

    @pytest.mark.parametrize(
        "trivial_top_p",
        [
            pytest.param(top_p)
            for top_p in TOP_P_NEUTRAL_VALS  # https://stackoverflow.com/a/75421799
        ],
    )
    def test_top_k_with_temperature(self, trivial_top_p: Optional[float]):
        params = SamplingParams(top_k=42, temperature=0.9, top_p=trivial_top_p)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        assert len(strat) == 3
        assert strat[0] == "top_k"
        assert strat[1] == 42
        assert strat[2] == pytest.approx(0.9)

    @pytest.mark.parametrize(
        "trivial_temperature",
        [
            pytest.param(temperature)
            for temperature in TEMPERATURE_NEUTRAL_VALS  # https://stackoverflow.com/a/75421799
        ],
    )
    def test_top_k_top_p(self, trivial_temperature: Optional[float]):
        params = SamplingParams(top_k=42, top_p=0.7, temperature=trivial_temperature)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        assert len(strat) == 4
        assert strat[0] == "top_k_top_p"
        assert strat[1] == 42
        assert strat[2] == pytest.approx(0.7)
        assert strat[3] == pytest.approx(1.0)

    def test_top_k_top_p_with_temperature(self):
        params = SamplingParams(top_k=42, top_p=0.7, temperature=0.9)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        assert len(strat) == 4
        assert strat[0] == "top_k_top_p"
        assert strat[1] == 42
        assert strat[2] == pytest.approx(0.7)
        assert strat[3] == pytest.approx(0.9)

    def test_param_validation(self):
        with pytest.raises(ValueError, match="require temperature >= 0, got temperature=-1"):
            SamplingParams(temperature=-1)

        with pytest.raises(ValueError, match="require 0 <= top_p <= 1, got top_p=-1"):
            SamplingParams(top_p=-1)

        with pytest.raises(ValueError, match="require 0 <= top_p <= 1, got top_p=2"):
            SamplingParams(top_p=2)

        with pytest.raises(ValueError, match="require top_k >= 0, got top_k=-1"):
            SamplingParams(top_k=-1)

    @pytest.mark.parametrize(
        "top_k, top_p",
        [
            pytest.param(top_k, top_p)
            # https://stackoverflow.com/a/75421799, does not work with nested loops
            for (top_k, top_p) in product(TOP_K_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS)
            if (top_k, top_p) != (None, None)
        ],
    )
    def test_trivial_top_k_top_p_not_greedy(self, top_k: Optional[int], top_p: Optional[float]):
        params = SamplingParams(top_k=top_k, top_p=top_p)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        assert len(strat) == 2
        assert strat[0] == "temperature"
        assert strat[1] == pytest.approx(1.0)

    @pytest.fixture
    def torch_sampler(self) -> TorchSampler:
        return TorchSampler(
            TorchSampler.Args(
                max_seq_len=123,
                max_draft_len=3,
                max_num_sequences=12,
                max_beam_width=1,
                max_total_draft_tokens=3,
            )
        )

    @pytest.mark.parametrize(
        "temperature, top_p, top_k",
        [
            pytest.param(temperature, top_p, top_k)
            # https://stackoverflow.com/a/75421799, does not work with nested loops
            for (temperature, top_p, top_k) in product(TEMPERATURE_VALS, TOP_P_VALS, TOP_K_VALS)
        ],
    )
    def test_should_provide_draft_probs_consistency(
        self,
        temperature: Optional[float],
        top_p: Optional[float],
        top_k: Optional[int],
        torch_sampler: TorchSampler,
    ):
        params = SamplingParams(top_k=top_k, top_p=top_p, temperature=temperature)
        self._check_params(params)
        request = self._build_mock_llm_request(params)
        strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
        is_greedy = strat is GREEDY

        assert torch_sampler.should_provide_draft_probs(request) == (not is_greedy)
@@ -23,6 +23,7 @@ def create_llm(model_dir):
        enable_chunked_prefill=True,
        cuda_graph_config=CudaGraphConfig(),
        kv_cache_config=trt_kv_cache_config,
        sampler_type="TRTLLMSampler",
        max_num_tokens=
        128  # Only one request longer than max_num_tokens is required to test chunked prefill
    )
@@ -18,6 +18,7 @@ import sys
import traceback
from typing import Any

import _pytest.outcomes
import pytest
import torch
import tqdm
@@ -65,8 +66,9 @@ def pytest_pyfunc_call(pyfuncitem) -> Any:
        return (yield)
    # NB: _pytest.outcomes.OutcomeException subclasses BaseException
    except BaseException as e:
        print(f"TEST RAISED ERROR: {e}")
        traceback.print_exception(e)
        if not isinstance(e, _pytest.outcomes.Skipped):
            print(f"TEST RAISED ERROR: {e}")
            traceback.print_exception(e)
        raise