[TRTLLM-8414][chore] BREAKING CHANGE: refine sampling strategy selection (#8132)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2025-10-08 15:46:50 +02:00 committed by GitHub
parent e98616512f
commit 8298e93bd8
8 changed files with 546 additions and 128 deletions
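
For orientation, a minimal sketch of the refined selection semantics, based on the SamplingParams.params_imply_greedy_decoding helper added in this change (the parameter combinations below are illustrative only):

from tensorrt_llm.sampling_params import SamplingParams

# Unset knobs, or any degenerate value, now imply greedy decoding.
assert SamplingParams.params_imply_greedy_decoding(temperature=None, top_p=None, top_k=None)
assert SamplingParams.params_imply_greedy_decoding(temperature=0.0, top_p=0.9, top_k=42)
assert SamplingParams.params_imply_greedy_decoding(temperature=0.7, top_p=None, top_k=1)

# Any genuinely non-trivial knob selects probabilistic sampling.
assert not SamplingParams.params_imply_greedy_decoding(temperature=0.7, top_p=None, top_k=None)
assert not SamplingParams.params_imply_greedy_decoding(temperature=None, top_p=0.9, top_k=None)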


@@ -34,6 +34,7 @@ extend_skip_glob = [
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
"tests/unittest/_torch/sampler/test_torch_sampler.py",
]
[tool.yapf]
@@ -65,6 +66,7 @@ ignore_patterns = [
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
"tests/unittest/_torch/sampler/test_torch_sampler.py",
]
[tool.codespell]
@@ -144,6 +146,7 @@ include = [
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
"tests/unittest/_torch/sampler/test_torch_sampler.py",
]
exclude = [
"**3rdparty/**",


@@ -234,8 +234,10 @@ class DemoEngine(ADEngine):
) -> Tuple[torch.Tensor, torch.Tensor]:
logits_shape = logits.shape
logits = logits.view(-1, logits_shape[-1]) # sampling_batch expects 2D logits
if isinstance(sampling_params.top_k, int):
idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k)
if isinstance(sampling_params.top_k, int) and sampling_params.top_k > 1:
idx_next, probs = top_k_sampling_batch(
logits, top_k=sampling_params.top_k, temperature=1.0
)
else:
idx_next, probs = greedy_search_sampling_batch(logits)
idx_next = idx_next.view(logits_shape[:-1])


@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import repeat
from typing import Any, List, Literal, Optional, cast
from typing import Any, List, Literal, Optional, TypeVar, cast
import torch
import torch.nn.functional as F
@@ -26,6 +26,7 @@ from tensorrt_llm.bindings.internal.runtime import (BufferManager, CudaEvent,
GptDecoderBatched)
from tensorrt_llm.executor.result import Logprob
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.sampling_params import SamplingParams
from ..speculative.spec_tree_manager import SpecTreeManager
from .finish_reason import FinishedState
@@ -195,84 +196,75 @@ class EarlyStopWithMMResult(Sampler):
def top_k_sampling_batch(
logits,
top_k=50,
generator: Optional[torch.Generator] = None
*,
top_k: int,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
logits_dim = logits.dim()
if logits_dim == 1:
logits = logits.unsqueeze(0)
# logits should be 2D [batch_size, vocab_size]
batch_size, vocab_size = logits.size()
# get first top_k logits of each sample and their indices
if top_k > 0:
values, indices = torch.topk(logits, top_k, dim=-1)
min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
# set the logits who is less than first top_k logits to -inf
logits = torch.where(logits < min_values,
torch.full_like(logits, float('-inf')), logits)
# compute probability distribution
softmax = torch.softmax(logits, dim=-1)
# sample from the distribution and generate result of [batch_size, 1]
next_tokens = torch.multinomial(softmax, num_samples=1,
generator=generator).squeeze(-1)
return next_tokens, softmax
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
top_k=top_k,
temperature=temperature,
generator=generator,
top_p=1,
)
def top_p_sampling_batch(
logits: torch.Tensor,
*,
top_p: float = 0.9,
temperature: float = 1.0,
top_p: float,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
top_p=top_p,
top_k=logits.size(1),
temperature=temperature,
generator=generator,
)
def temperature_sampling_batch(
logits: torch.Tensor,
*,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
top_p=1,
top_k=logits.size(1),
temperature=temperature,
generator=generator,
)
def top_k_top_p_sampling_batch(
logits: torch.Tensor,
*,
top_k: int,
top_p: float,
temperature: float,
generator: Optional[torch.Generator] = None
) -> tuple[torch.Tensor, torch.Tensor]:
logits_dim = logits.dim()
assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
if temperature != 0:
logits = logits / max(temperature, 1e-5)
# sort the logits of each sample in descending order
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
# compute cumulative probability distribution of each sample
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
dim=-1)
# get the location of top_p
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
# set the logits to -inf whose is outside top_p
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float('-inf'))
# compute probability distribution
softmax = torch.softmax(logits, dim=-1)
# sample from the distribution and generate result of [batch_size, 1]
next_tokens = torch.multinomial(softmax, num_samples=1,
generator=generator).squeeze(-1)
return next_tokens, softmax
def top_k_top_p_sampling_batch(logits: torch.Tensor,
*,
top_k: int,
top_p: float,
temperature: float = 1.0,
generator: Optional[torch.Generator] = None):
logits_dim = logits.dim()
assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
if temperature != 0:
logits = logits / max(temperature, 1e-5)
assert temperature > 0, "non-greedy sampling requires valid temperature"
logits = logits / max(temperature, 1e-5)
batch_size, vocab_size = logits.size()
# get first top_k logits of each sample and their indices
if top_k > 0:
assert top_k > 1, "non-greedy sampling requires valid top_k"
need_top_k = top_k < vocab_size
assert top_p > 0, "non-greedy sampling requires valid top_p"
need_top_p = top_p < 1
# top-K: mask out logits not belonging to the top-K for each sample
if need_top_k:
values, _ = torch.topk(logits, top_k, dim=-1)
min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
@@ -280,21 +272,28 @@ def top_k_top_p_sampling_batch(logits: torch.Tensor,
logits = torch.where(logits < min_values,
torch.full_like(logits, float('-inf')), logits)
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
# top-p: mask out logits outside the nucleus
if need_top_p:
sorted_logits, sorted_indices = torch.sort(logits,
descending=True,
dim=-1)
# compute cumulative probability distribution of each sample
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
dim=-1)
# compute cumulative probability distribution of each sample
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1),
dim=-1)
# get the location of top_p
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
# get the location of top_p
# NB: Currently selecting the smallest index with cumulative_probs > top_p.
# Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:,
1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
# set the logits to -inf whose is outside top_p
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float('-inf'))
# set the logits to -inf for token indices outside top_p
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float('-inf'))
# compute probability distribution
softmax = torch.softmax(logits, dim=-1)
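
The reworked helpers take their knobs as keyword-only arguments with an explicit temperature. A minimal usage sketch (illustrative values; the import path follows the new unit test added in this change):

import torch

from tensorrt_llm._torch.pyexecutor.sampler import top_k_top_p_sampling_batch

# Two requests over a toy 5-token vocabulary.
logits = torch.tensor([[2.0, 1.5, 0.3, -1.0, -2.0],
                       [0.1, 0.2, 0.3, 0.4, 0.5]])
generator = torch.Generator().manual_seed(0)
next_tokens, probs = top_k_top_p_sampling_batch(
    logits, top_k=3, top_p=0.9, temperature=0.8, generator=generator)
# next_tokens has shape [batch_size]; probs is the filtered, renormalized
# [batch_size, vocab_size] distribution the tokens were drawn from.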
@@ -359,48 +358,100 @@ def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor,
return new_token
TopK = tuple[Literal["top_k"], int]
TemperatureOnly = tuple[Literal["temperature"], float]
TopK = tuple[Literal["top_k"], int, float]
TopP = tuple[Literal["top_p"], float, float]
TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
Greedy = tuple[Literal["greedy"], None]
GREEDY: Greedy = ("greedy", None)
Strategy = TopK | TopP | Greedy | TopKTopP
Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
T = TypeVar('T')
def _request_strategy(request: LlmRequest) -> Strategy:
# top_p and top_K with temperature=0.0 reduces to greedy
# sampling
temperature = request.sampling_config.temperature
if temperature is not None:
temperature = temperature[0]
if temperature == 0.0:
return GREEDY
# Due to tensorrt_llm::runtime::SamplingConfig using vectors, params
# in LlmRequest.sampling_params are either None or single-element lists.
# This helper method simplifies code using such params.
def _unwrap_singleton(p: Optional[List[T]]) -> Optional[T]:
if p is None:
return None
t, = p
return t
if request.sampling_config.top_k is not None and len(
request.sampling_config.top_k
) > 0 and request.sampling_config.top_p is not None and len(
request.sampling_config.top_p) > 0:
return ("top_k_top_p", request.sampling_config.top_k[0],
request.sampling_config.top_p[0], temperature)
elif request.sampling_config.top_p is not None and len(
request.sampling_config.top_p) > 0:
top_p = request.sampling_config.top_p[0]
return ("top_p", top_p, temperature)
elif request.sampling_config.top_k is not None and len(
request.sampling_config.top_k) > 0:
return ("top_k", request.sampling_config.top_k[0])
else:
@dataclass(frozen=True, kw_only=True)
class TorchSamplerSamplingParams:
"""Subset of tensorrt_llm::runtime::SamplingConfig handled by TorchSampler."""
temperature: Optional[float]
top_p: Optional[float]
top_k: Optional[int]
def _request_get_sampling_params(
request: LlmRequest) -> TorchSamplerSamplingParams:
sampling_config = request.sampling_config
temperature = _unwrap_singleton(
cast(Optional[List[float]], sampling_config.temperature))
top_p = _unwrap_singleton(cast(Optional[List[float]],
sampling_config.top_p))
top_k = _unwrap_singleton(cast(Optional[List[int]], sampling_config.top_k))
return TorchSamplerSamplingParams(
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
# The semantics are specified in the doc-string of SamplingParams
params = _request_get_sampling_params(request)
temperature = params.temperature
top_p = params.top_p
top_k = params.top_k
if SamplingParams.params_imply_greedy_decoding(
temperature=temperature,
top_p=top_p,
top_k=top_k,
):
return GREEDY
# --- resolving default values
# NB: not greedy, hence temperature != 0 if specified
temperature = temperature or 1.0
# NB: not greedy, hence top_p != 0 if specified
top_p = top_p or 1.0
# NB: not greedy, hence top_k != 1 if specified
# (0 and vocab_size are equivalent)
top_k = top_k or vocab_size
assert top_k > 1, "non-greedy sampling requires valid top_k"
need_top_k = top_k < vocab_size
assert top_p > 0, "non-greedy sampling requires valid top_p"
need_top_p = top_p < 1
if need_top_p:
if need_top_k:
return ("top_k_top_p", top_k, top_p, temperature)
return ("top_p", top_p, temperature)
if need_top_k:
return ("top_k", top_k, temperature)
return ("temperature", temperature)
def _group_requests_by_sampling_strategy(
requests: Iterable[LlmRequest],
*,
pin_memory: bool = False) -> dict[Strategy, torch.Tensor]:
pin_memory: bool = False,
vocab_size: int) -> dict[Strategy, torch.Tensor]:
# NB: Client code relies on request indices in returned torch.Tensor being sorted.
strategy_dict: dict[Strategy, list[int]] = defaultdict(list)
for req_index, req in enumerate(requests):
strategy_dict[_request_strategy(req)].append(req_index)
strategy_dict[_request_strategy(
req, vocab_size=vocab_size)].append(req_index)
return {
strategy: torch.tensor(indices,
pin_memory=pin_memory,
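
To make the mapping concrete, a minimal sketch of how request parameters resolve to strategy tuples, mirroring the mock-request pattern used in the new unit test at the end of this diff (parameter values are chosen so that float round-tripping through the bindings stays exact):

from typing import cast

from tensorrt_llm._torch.pyexecutor.sampler import GREEDY, LlmRequest, _request_strategy
from tensorrt_llm.bindings import SamplingConfig
from tensorrt_llm.sampling_params import SamplingParams


class _MockRequest:
    sampling_config: SamplingConfig


def _strategy_for(params: SamplingParams, vocab_size: int = 1000):
    request = _MockRequest()
    request.sampling_config = SamplingConfig(params._get_sampling_config())
    return _request_strategy(cast(LlmRequest, request), vocab_size=vocab_size)


assert _strategy_for(SamplingParams()) is GREEDY
assert _strategy_for(SamplingParams(top_k=1)) is GREEDY
assert _strategy_for(SamplingParams(top_k=42)) == ("top_k", 42, 1.0)
assert _strategy_for(SamplingParams(top_p=0.75, temperature=0.5)) == ("top_p", 0.75, 0.5)
assert _strategy_for(SamplingParams(top_k=42, top_p=0.75)) == ("top_k_top_p", 42, 0.75, 1.0)
assert _strategy_for(SamplingParams(temperature=0.5)) == ("temperature", 0.5)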
@@ -418,23 +469,32 @@ def sample(
) -> tuple[torch.Tensor, torch.Tensor]:
filter_softmax = True
match strategy:
case ("top_k", top_k):
tokens, softmax = top_k_sampling_batch(logits, top_k, generator)
case ("top_k", top_k, temperature):
tokens, softmax = top_k_sampling_batch(logits,
top_k=top_k,
temperature=temperature,
generator=generator)
case ("top_p", top_p, temperature):
tokens, softmax = top_p_sampling_batch(
logits,
top_p=top_p,
generator=generator,
**(dict(temperature=temperature)
if temperature is not None else dict()))
temperature=temperature,
)
case ("top_k_top_p", top_k, top_p, temperature):
tokens, softmax = top_k_top_p_sampling_batch(
logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
generator=generator,
**(dict(temperature=temperature)
if temperature is not None else dict()))
)
case ("temperature", temperature):
tokens, softmax = temperature_sampling_batch(
logits,
temperature=temperature,
generator=generator,
)
case ("greedy", None):
tokens, softmax = greedy_search_sampling_batch(
logits, softmax_indices=softmax_indices)
@@ -1070,7 +1130,11 @@ class TorchSampler(Sampler):
def _process_draft_tokens_rejection_sampling(
self, request: LlmRequest, new_tokens: list[list[list[int]]],
new_tokens_tensor: torch.Tensor) -> int:
sampling_strategy = _request_strategy(request)
# FIXME: Passing a dummy vocab_size can trigger unnecessary top-k
# filtering that selects vocab_size logits out of vocab_size in
# total (a costly no-op). The 'sample' below should generally be
# avoided by retaining the draft_probs during drafting (TRTLLM-7772).
sampling_strategy = _request_strategy(request, vocab_size=2**31)
generator = self.get_generator(request.py_draft_logits.device)
_, draft_probs = sample(sampling_strategy,
request.py_draft_logits,
@@ -1329,7 +1393,7 @@ class TorchSampler(Sampler):
dim=-1)
requests_by_strategy = _group_requests_by_sampling_strategy(
requests, pin_memory=True)
requests, pin_memory=True, vocab_size=logits_cuda.size(1))
generator_cuda = self.get_generator(cuda_device)
# FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
@@ -1706,8 +1770,17 @@ class TorchSampler(Sampler):
@override
def should_provide_draft_probs(self, request: LlmRequest) -> bool:
params = _request_get_sampling_params(request)
temperature = params.temperature
top_p = params.top_p
top_k = params.top_k
# Do not request draft probs when sampling is greedy.
return _request_strategy(request) is not GREEDY
return not SamplingParams.params_imply_greedy_decoding(
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
class Algorithms:


@@ -154,13 +154,25 @@ class SamplingParams:
best_of (int, optional): Number of sequences to consider for best output. Defaults to None.
use_beam_search (bool): Whether to use beam search. Defaults to False.
top_k (int, optional): Controls number of logits to sample from. None means using C++ runtime default 0, i.e., all logits. Defaults to None.
top_p (float, optional): Controls the top-P probability to sample from. None means using C++ runtime default 0.f. Defaults to None.
top_k (int, optional): Controls number of logits to sample from. Can assume non-negative values, where 0 means 'all logits'. Defaults to None.
The value None is treated as "not specified" in the following.
If neither temperature, top_p, nor top_k are specified, sampling is greedy.
If temperature > 0 and/or top_p < 1 are specified, sampling will proceed accordingly and top_k will default to top_k = 0.
Setting top_k = 1 results in greedy sampling.
top_p (float, optional): Controls the top-P probability to sample from. Can have values between 0 and 1. Defaults to None.
The value None is treated as "not specified" in the following.
If neither temperature, top_p, nor top_k are specified, sampling is greedy.
If temperature > 0 and/or top_k > 1 are specified, sampling will proceed accordingly and top_p will default to top_p = 1.
Setting top_p = 0 should result in greedy sampling, but is currently disallowed in the backend.
top_p_min (float, optional): Controls decay in the top-P algorithm. topPMin is lower-bound. None means using C++ runtime default 1.e-6. Defaults to None.
top_p_reset_ids (int, optional): Controls decay in the top-P algorithm. Indicates where to reset the decay. None means using C++ runtime default 1. Defaults to None.
top_p_decay (float, optional): Controls decay in the top-P algorithm. The decay value. None means using C++ runtime default 1.f. Defaults to None.
seed (int, optional): Controls the random seed used by the random number generator in sampling. None means using C++ runtime default 0. Defaults to None.
temperature (float, optional): Controls the modulation of logits when sampling new tokens. It can have values > 0.f. None means using C++ runtime default 1.0f. Defaults to None.
temperature (float, optional): Controls the modulation of logits when sampling new tokens. It can have values >= 0.f. Defaults to None.
The value None is treated as "not specified" in the following.
If neither temperature, top_p, nor top_k are specified, sampling is greedy.
If top_p < 1 and/or top_k > 1 are specified, sampling will proceed accordingly and temperature will default to temperature = 1.
Setting temperature = 0 results in greedy sampling.
min_tokens (int, optional): Lower bound on the number of tokens to generate. Values < 1 have no effect. None means using C++ runtime default 1. Defaults to None.
beam_search_diversity_rate (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None.
repetition_penalty (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None.
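
The defaulting rules above can be summarized with a few illustrative constructions (a minimal sketch; the comments restate the semantics documented in this doc-string):

from tensorrt_llm.sampling_params import SamplingParams

# Greedy decoding: nothing specified, or a degenerate knob value.
SamplingParams()               # greedy
SamplingParams(top_k=1)        # greedy
SamplingParams(temperature=0)  # greedy

# Probabilistic sampling: unspecified knobs fall back to neutral values.
SamplingParams(top_p=0.9)        # top-P sampling; top_k defaults to 0 (all logits), temperature to 1
SamplingParams(top_k=40)         # top-K sampling; top_p defaults to 1, temperature to 1
SamplingParams(temperature=0.7)  # temperature-only sampling over the full vocabulary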
@@ -296,11 +308,19 @@ class SamplingParams:
For instance, while the greedy decoding with n > 1 is capable in the
Executor class of C++ runtime, the LLM API disallows such combination.
"""
if self.best_of < self.n:
if self.top_p is not None and (self.top_p < 0 or self.top_p > 1):
raise ValueError(f"require 0 <= top_p <= 1, got top_p={self.top_p}")
if self.top_k is not None and self.top_k < 0:
raise ValueError(f"require top_k >= 0, got top_k={self.top_k}")
if self.temperature is not None and self.temperature < 0:
raise ValueError(f"require temperature >= 0, got temperature={self.temperature}")
if self.best_of is not None and self.best_of < self.n:
raise ValueError(f"best_of ({self.best_of}) cannot be less than n ({self.n})")
if (
self.best_of > 1
self.best_of is not None
and self.best_of > 1
and self._greedy_decoding
and not os.environ.get("TLLM_ALLOW_N_GREEDY_DECODING", None)
):
@@ -324,12 +344,25 @@ class SamplingParams:
self.logprobs = self.logprobs and int(self.logprobs)
self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs)
# NB: Static, because downstream code only holds instances of
# bindings.SamplingConfig (not SamplingParams).
@staticmethod
def params_imply_greedy_decoding(
*, temperature: Optional[float], top_p: Optional[float], top_k: Optional[int]
):
return (
(temperature is None and top_p is None and top_k is None)
or top_k == 1
or top_p == 0.0
or temperature == 0
)
@property
def _greedy_decoding(self) -> bool:
return (
not self.use_beam_search
and (self.top_k is None or self.top_k == 1)
and (self.top_p is None or self.top_p == 0.0)
return not self.use_beam_search and self.params_imply_greedy_decoding(
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
)
@property


@@ -14,6 +14,7 @@ l0_a10:
backend: pytorch
tests:
# ------------- PyTorch tests ---------------
- unittest/_torch/sampler/test_torch_sampler.py
- unittest/_torch/modeling/test_modeling_mistral.py
- unittest/_torch/modeling/test_modeling_pixtral.py
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no


@@ -0,0 +1,303 @@
from itertools import product
from typing import Optional, cast
import pytest
from utils.util import force_ampere
from tensorrt_llm._torch.pyexecutor.sampler import (
GREEDY,
LlmRequest,
TorchSampler,
_request_strategy,
)
from tensorrt_llm.bindings import SamplingConfig
from tensorrt_llm.sampling_params import SamplingParams
@force_ampere
class TestStrategySelection:
VOCAB_SIZE = 1000
TOP_K_VALS = [None, 0, 1, 42, 1000]
TOP_P_VALS = [None, 0, 0.42, 1]
TEMPERATURE_VALS = [None, 0, 1.42]
# For non-greedy sampling, the following choices have no effect.
TOP_P_NEUTRAL_VALS = [None, 1]
TOP_K_NEUTRAL_VALS = [None, 0, VOCAB_SIZE]
TEMPERATURE_NEUTRAL_VALS = [None, 1]
TEMPERATURE_NOT_GREEDY = [0.42] + [t for t in TEMPERATURE_NEUTRAL_VALS if t is not None]
class MockLlmRequest:
sampling_config: SamplingConfig
def _check_params(self, params: SamplingParams):
# cf. description of 'top_p' in doc-string of SamplingParams and
# test_top_p_0_disallowed below.
if params.top_p == 0:
pytest.skip("top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig")
def test_top_p_0_disallowed(self):
# If this xpasses, update _check_params and doc-string of SamplingParams.
params = SamplingParams(top_p=0)
pytest.xfail("top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig")
params._get_sampling_config()
def _build_mock_llm_request(self, params: SamplingParams) -> LlmRequest:
request = self.MockLlmRequest()
request.sampling_config = SamplingConfig(params._get_sampling_config())
return cast(LlmRequest, request)
def test_defaults(self):
# NB: The code in _request_strategy relies on the default values below.
default_params = SamplingParams()
assert default_params.top_k is None
assert default_params.top_p is None
assert default_params.temperature is None
def test_defaults_config(self):
# NB: The code in _request_strategy relies on the default values below.
default_config = SamplingParams()._get_sampling_config()
assert default_config.top_k is None
assert default_config.top_p is None
assert default_config.temperature is None
def test_defaults_request(self):
# NB: The code in _request_strategy relies on the default values below.
request = self._build_mock_llm_request(SamplingParams())
default_config = request.sampling_config
assert default_config.top_k is None
assert default_config.top_p is None
assert default_config.temperature is None
def test_default_is_greedy(self):
request = self._build_mock_llm_request(SamplingParams())
assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY
@pytest.mark.parametrize(
"top_p, top_k",
[
pytest.param(top_p, top_k)
# https://stackoverflow.com/a/75421799, does not work with nested loops
for (top_k, top_p) in product(TOP_K_VALS, TOP_P_VALS)
],
)
def test_temperature_0_is_greedy(self, top_p: Optional[float], top_k: Optional[int]):
params = SamplingParams(temperature=0, top_p=top_p, top_k=top_k)
self._check_params(params)
request = self._build_mock_llm_request(params)
assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY
@pytest.mark.parametrize(
"temperature, top_k",
[
pytest.param(temperature, top_k)
# https://stackoverflow.com/a/75421799, does not work with nested loops
for (temperature, top_k) in product(TEMPERATURE_VALS, TOP_K_VALS)
],
)
def test_top_p_0_is_greedy(self, temperature: Optional[float], top_k: Optional[int]):
params = SamplingParams(top_p=0, temperature=temperature, top_k=top_k)
self._check_params(params)
request = self._build_mock_llm_request(params)
assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY
@pytest.mark.parametrize(
"temperature, top_p",
[
pytest.param(temperature, top_p)
# https://stackoverflow.com/a/75421799, does not work with nested loops
for (temperature, top_p) in product(TEMPERATURE_VALS, TOP_P_VALS)
],
)
def test_top_k_1_is_greedy(self, temperature: Optional[float], top_p: Optional[float]):
params = SamplingParams(top_p=top_p, temperature=temperature, top_k=1)
self._check_params(params)
request = self._build_mock_llm_request(params)
assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY
@pytest.mark.parametrize(
"temperature, trivial_top_p, trivial_top_k",
[
pytest.param(temperature, top_p, top_k)
# https://stackoverflow.com/a/75421799, does not work with nested loops
for (temperature, top_k, top_p) in product(
TEMPERATURE_NOT_GREEDY, TOP_K_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS
)
],
)
def test_temperature_only(
self, temperature: float, trivial_top_p: Optional[float], trivial_top_k: Optional[int]
):
params = SamplingParams(temperature=temperature, top_p=trivial_top_p, top_k=trivial_top_k)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
assert len(strat) == 2
assert strat[0] == "temperature"
assert strat[1] == pytest.approx(temperature)
@pytest.mark.parametrize(
"trivial_temperature, trivial_top_k",
[
pytest.param(temperature, top_k)
# https://stackoverflow.com/a/75421799, does not work with nested loops
for (temperature, top_k) in product(TEMPERATURE_NEUTRAL_VALS, TOP_K_NEUTRAL_VALS)
],
)
def test_top_p_only(self, trivial_temperature: Optional[float], trivial_top_k: Optional[int]):
params = SamplingParams(top_p=0.42, temperature=trivial_temperature, top_k=trivial_top_k)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
assert len(strat) == 3
assert strat[0] == "top_p"
assert strat[1] == pytest.approx(0.42)
assert strat[2] == pytest.approx(1.0)
@pytest.mark.parametrize(
"trivial_top_k",
[
pytest.param(top_k)
for top_k in TOP_K_NEUTRAL_VALS # https://stackoverflow.com/a/75421799
],
)
def test_top_p_with_temperature(self, trivial_top_k: Optional[int]):
params = SamplingParams(top_p=0.42, temperature=0.9, top_k=trivial_top_k)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
assert len(strat) == 3
assert strat[0] == "top_p"
assert strat[1] == pytest.approx(0.42)
assert strat[2] == pytest.approx(0.9)
@pytest.mark.parametrize(
"trivial_temperature, trivial_top_p",
[
pytest.param(temperature, top_p)
# https://stackoverflow.com/a/75421799, does not work with nested loops
for (temperature, top_p) in product(TEMPERATURE_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS)
],
)
def test_top_k_only(self, trivial_temperature: Optional[float], trivial_top_p: Optional[float]):
params = SamplingParams(top_k=42, temperature=trivial_temperature, top_p=trivial_top_p)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
assert len(strat) == 3
assert strat[0] == "top_k"
assert strat[1] == 42
assert strat[2] == pytest.approx(1.0)
@pytest.mark.parametrize(
"trivial_top_p",
[
pytest.param(top_p)
for top_p in TOP_P_NEUTRAL_VALS # https://stackoverflow.com/a/75421799
],
)
def test_top_k_with_temperature(self, trivial_top_p: Optional[float]):
params = SamplingParams(top_k=42, temperature=0.9, top_p=trivial_top_p)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
assert len(strat) == 3
assert strat[0] == "top_k"
assert strat[1] == 42
assert strat[2] == pytest.approx(0.9)
@pytest.mark.parametrize(
"trivial_temperature",
[
pytest.param(temperature)
for temperature in TEMPERATURE_NEUTRAL_VALS # https://stackoverflow.com/a/75421799
],
)
def test_top_k_top_p(self, trivial_temperature: Optional[float]):
params = SamplingParams(top_k=42, top_p=0.7, temperature=trivial_temperature)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
assert len(strat) == 4
assert strat[0] == "top_k_top_p"
assert strat[1] == 42
assert strat[2] == pytest.approx(0.7)
assert strat[3] == pytest.approx(1.0)
def test_top_k_top_p_with_temperature(self):
params = SamplingParams(top_k=42, top_p=0.7, temperature=0.9)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
assert len(strat) == 4
assert strat[0] == "top_k_top_p"
assert strat[1] == 42
assert strat[2] == pytest.approx(0.7)
assert strat[3] == pytest.approx(0.9)
def test_param_validation(self):
with pytest.raises(ValueError, match="require temperature >= 0, got temperature=-1"):
SamplingParams(temperature=-1)
with pytest.raises(ValueError, match="require 0 <= top_p <= 1, got top_p=-1"):
SamplingParams(top_p=-1)
with pytest.raises(ValueError, match="require 0 <= top_p <= 1, got top_p=2"):
SamplingParams(top_p=2)
with pytest.raises(ValueError, match="require top_k >= 0, got top_k=-1"):
SamplingParams(top_k=-1)
@pytest.mark.parametrize(
"top_k, top_p",
[
pytest.param(top_k, top_p)
# https://stackoverflow.com/a/75421799, does not work with nested loops
for (top_k, top_p) in product(TOP_K_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS)
if (top_k, top_p) != (None, None)
],
)
def test_trivial_top_k_top_p_not_greedy(self, top_k: Optional[int], top_p: Optional[float]):
params = SamplingParams(top_k=top_k, top_p=top_p)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
assert len(strat) == 2
assert strat[0] == "temperature"
assert strat[1] == pytest.approx(1.0)
@pytest.fixture
def torch_sampler(self) -> TorchSampler:
return TorchSampler(
TorchSampler.Args(
max_seq_len=123,
max_draft_len=3,
max_num_sequences=12,
max_beam_width=1,
max_total_draft_tokens=3,
)
)
@pytest.mark.parametrize(
"temperature, top_p, top_k",
[
pytest.param(temperature, top_p, top_k)
# https://stackoverflow.com/a/75421799, does not work with nested loops
for (temperature, top_p, top_k) in product(TEMPERATURE_VALS, TOP_P_VALS, TOP_K_VALS)
],
)
def test_should_provide_draft_probs_consistency(
self,
temperature: Optional[float],
top_p: Optional[float],
top_k: Optional[int],
torch_sampler: TorchSampler,
):
params = SamplingParams(top_k=top_k, top_p=top_p, temperature=temperature)
self._check_params(params)
request = self._build_mock_llm_request(params)
strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE)
is_greedy = strat is GREEDY
assert torch_sampler.should_provide_draft_probs(request) == (not is_greedy)


@@ -23,6 +23,7 @@ def create_llm(model_dir):
enable_chunked_prefill=True,
cuda_graph_config=CudaGraphConfig(),
kv_cache_config=trt_kv_cache_config,
sampler_type="TRTLLMSampler",
max_num_tokens=
128 # Only one request longer than max_num_tokens is required to test chunked prefill
)


@@ -18,6 +18,7 @@ import sys
import traceback
from typing import Any
import _pytest.outcomes
import pytest
import torch
import tqdm
@@ -65,8 +66,9 @@ def pytest_pyfunc_call(pyfuncitem) -> Any:
return (yield)
# NB: _pytest.outcomes.OutcomeException subclasses BaseException
except BaseException as e:
print(f"TEST RAISED ERROR: {e}")
traceback.print_exception(e)
if not isinstance(e, _pytest.outcomes.Skipped):
print(f"TEST RAISED ERROR: {e}")
traceback.print_exception(e)
raise