[TRTLLM-8436][feat] batched sampling and top-k logprobs improvements (#8398)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2025-10-20 11:15:41 +02:00 committed by GitHub
parent d05079ba4b
commit 97ce0ecefe
8 changed files with 1159 additions and 993 deletions

View File

@@ -35,6 +35,8 @@ extend_skip_glob = [
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
"tests/unittest/_torch/sampler/test_torch_sampler.py",
"tensorrt_llm/_torch/pyexecutor/sampler.py",
"tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
]
[tool.yapf]
@@ -67,6 +69,8 @@ ignore_patterns = [
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
"tests/unittest/_torch/sampler/test_torch_sampler.py",
"tensorrt_llm/_torch/pyexecutor/sampler.py",
"tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
]
[tool.codespell]
@@ -102,6 +106,8 @@ exclude = [
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
"tensorrt_llm/_torch/pyexecutor/sampler.py",
"tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
]
@@ -147,6 +153,8 @@ include = [
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
"tests/unittest/_torch/sampler/test_torch_sampler.py",
"tensorrt_llm/_torch/pyexecutor/sampler.py",
"tensorrt_llm/_torch/pyexecutor/sampling_utils.py",
]
exclude = [
"**3rdparty/**",

View File

@@ -13,7 +13,7 @@ from ....executor.request import GenerationRequest
from ....executor.result import CompletionOutput, GenerationResult
from ....inputs.multimodal import MultimodalParams
from ....sampling_params import SamplingParams
from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch
from ...pyexecutor.sampling_utils import greedy_search_sampling_batch, top_k_sampling_batch
from ..distributed import common as dist_ad
from ..utils.logger import ad_logger
from .ad_executor import ADEngine

View File

@@ -1204,13 +1204,12 @@ class PyExecutor:
self._kv_connector_terminate_requests()
if self.enable_iter_perf_stats:
if self.enable_iter_perf_stats and sample_state is not None:
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
'num_ctx_tokens']
self._process_iter_stats(
finished_requests, self.active_requests,
BatchState(sample_state=SampleState(
scheduled_requests=scheduled_batch),
BatchState(sample_state=sample_state,
iter_stats=iter_stats,
iter_start_time=iter_start_time))

File diff suppressed because it is too large.

View File

@@ -0,0 +1,406 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for sampling.
Code in this module should operate on logits and probs, without
referring to types like LlmRequest.
"""
import abc
import sys
from dataclasses import dataclass
from typing import Generic, Literal, Optional, TypeAlias, TypeVar, cast
import torch
from tensorrt_llm.sampling_params import SamplingParams
if sys.version_info[:2] >= (3, 12):
from typing import override
else:
from typing_extensions import override
TemperatureOnly = tuple[Literal["temperature"], float]
TopK = tuple[Literal["top_k"], int, float]
TopP = tuple[Literal["top_p"], float, float]
TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
Greedy = tuple[Literal["greedy"], None]
GREEDY: Greedy = ("greedy", None)
Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
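# NB: Illustrative values (not part of the original change). A Strategy is a
# plain tagged tuple, e.g.:
#   ("temperature", 0.7)            # temperature-only sampling
#   ("top_k", 50, 0.7)              # top-k with temperature
#   ("top_p", 0.9, 0.7)             # top-p (nucleus) with temperature
#   ("top_k_top_p", 50, 0.9, 0.7)   # combined top-k and top-p
#   GREEDY                          # ("greedy", None)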
@dataclass(frozen=True, kw_only=True)
class UtilsSamplingParams:
"""Subset of tensorrt_llm::runtime::SamplingConfig supported by sampling_utils."""
temperature: Optional[float]
top_p: Optional[float]
top_k: Optional[int]
def resolve_sampling_strategy(params: UtilsSamplingParams, *, vocab_size: int) -> Strategy:
# The semantics are specified in the doc-string of SamplingParams
temperature = params.temperature
top_p = params.top_p
top_k = params.top_k
if SamplingParams.params_imply_greedy_decoding(
temperature=temperature,
top_p=top_p,
top_k=top_k,
):
return GREEDY
# --- resolving default values
# NB: not greedy, hence temperature != 0 if specified
temperature = temperature or 1.0
# NB: not greedy, hence top_p != 0 if specified
top_p = top_p or 1.0
# NB: not greedy, hence top_k != 1 if specified
# (0 and vocab_size are equivalent)
top_k = top_k or vocab_size
assert top_k > 1, "non-greedy sampling requires valid top_k"
need_top_k = top_k < vocab_size
assert top_p > 0, "non-greedy sampling requires valid top_p"
need_top_p = top_p < 1
if need_top_p:
if need_top_k:
return ("top_k_top_p", top_k, top_p, temperature)
return ("top_p", top_p, temperature)
if need_top_k:
return ("top_k", top_k, temperature)
return ("temperature", temperature)
def top_k_sampling_batch(
logits,
*,
top_k: int,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
top_k=top_k,
temperature=temperature,
generator=generator,
top_p=1,
)
def top_p_sampling_batch(
logits: torch.Tensor,
*,
top_p: float,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
top_p=top_p,
top_k=logits.size(1),
temperature=temperature,
generator=generator,
)
def temperature_sampling_batch(
logits: torch.Tensor,
*,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
top_p=1,
top_k=logits.size(1),
temperature=temperature,
generator=generator,
)
def top_k_top_p_sampling_batch(
logits: torch.Tensor,
*,
top_k: int,
top_p: float,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
logits_dim = logits.dim()
assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
assert temperature > 0, "non-greedy sampling requires valid temperature"
logits = logits / max(temperature, 1e-5)
batch_size, vocab_size = logits.size()
assert top_k > 1, "non-greedy sampling requires valid top_k"
need_top_k = top_k < vocab_size
assert top_p > 0, "non-greedy sampling requires valid top_p"
need_top_p = top_p < 1
# top-K: mask out logits not belonging to the top-K for each sample
if need_top_k:
values, _ = torch.topk(logits, top_k, dim=-1)
min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size)
# set logits below the k-th largest value to -inf
logits = torch.where(logits < min_values, torch.full_like(logits, float("-inf")), logits)
# top-p: mask out logits outside the nucleus
if need_top_p:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
# compute cumulative probability distribution of each sample
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# get the location of top_p
# NB: Currently selecting the smallest index with cumulative_probs > top_p.
# Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
# set the logits to -inf for token indices outside top_p
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, float("-inf"))
# compute probability distribution
softmax = torch.softmax(logits, dim=-1)
# sample from the distribution and generate result of [batch_size, 1]
next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
return next_tokens, softmax
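# Illustrative call (assumed shapes, not part of the original change):
#   logits = torch.randn(4, 32000)
#   tokens, probs = top_k_top_p_sampling_batch(
#       logits, top_k=50, top_p=0.9, temperature=0.8)
# 'tokens' has shape [4]; 'probs' is the masked, renormalized distribution of
# shape [4, 32000] from which the tokens were drawn.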
def greedy_search_sampling_batch(
logits,
*,
return_probs: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
next_tokens = torch.argmax(logits, dim=-1)
softmax: Optional[torch.Tensor] = None
if return_probs:
softmax = torch.softmax(logits, dim=-1)
return next_tokens, softmax
def get_rejected_indices(
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
generator: torch.Generator,
draft_tokens: list[int],
) -> torch.Tensor:
# NB: ModelDrafter._pad_to_max_draft_tokens pads draft_tokens, but
# not draft_probs. Relying on shape of draft_probs here.
num_draft_tokens = draft_probs.size(0)
draft_tokens = draft_tokens[:num_draft_tokens]
# NB: torch.arange is needed to enable "advanced indexing",
# cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing
token_idx = torch.arange(num_draft_tokens, dtype=torch.int32, device=generator.device)
draft_tokens_cuda = torch.tensor(draft_tokens, dtype=torch.int32, pin_memory=True).to(
device=generator.device, non_blocking=True
)
p = draft_probs[token_idx, draft_tokens_cuda]
q = target_probs.squeeze(0)[token_idx, draft_tokens_cuda]
accept_probs = torch.minimum(torch.ones((), device=generator.device, dtype=q.dtype), q / p)
# Use deterministic random generation for multi-GPU consistency
rejected_indices = (
torch.rand(accept_probs.shape, generator=generator, device=accept_probs.device)
> accept_probs
).nonzero()
return rejected_indices
def sample_rejected(
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
generator: torch.Generator,
num_accepted: int,
) -> int:
last_draft = draft_probs[num_accepted]
last_target = target_probs[num_accepted]
new = last_target - last_draft
new = torch.where(new > 0, new, 0.0)
new_token = torch.multinomial(new, num_samples=1, generator=generator).squeeze(-1)
return cast(int, new_token.item())
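# Sketch of the rejection-sampling flow these two helpers support (assumed
# usage, not part of the original change):
#   rejected = get_rejected_indices(draft_probs, target_probs, generator, draft_tokens)
#   if rejected.numel() == 0:  # accept every draft token
#       ...
#   else:
#       num_accepted = int(rejected[0])  # index of the first rejected draft token
#       replacement = sample_rejected(draft_probs, target_probs, generator, num_accepted)
#       # keep draft_tokens[:num_accepted] and append 'replacement'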
def sample(
strategy: Strategy,
logits: torch.Tensor,
*,
generator: Optional[torch.Generator] = None,
return_probs: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
match strategy:
case ("top_k", top_k, temperature):
tokens, softmax = top_k_sampling_batch(
logits, top_k=top_k, temperature=temperature, generator=generator
)
case ("top_p", top_p, temperature):
tokens, softmax = top_p_sampling_batch(
logits,
top_p=top_p,
generator=generator,
temperature=temperature,
)
case ("top_k_top_p", top_k, top_p, temperature):
tokens, softmax = top_k_top_p_sampling_batch(
logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
generator=generator,
)
case ("temperature", temperature):
tokens, softmax = temperature_sampling_batch(
logits,
temperature=temperature,
generator=generator,
)
case ("greedy", None):
tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs)
return tokens, softmax
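# Illustrative dispatch (assumed, not part of the original change): any resolved
# Strategy tuple can be passed directly, e.g.
#   gen = torch.Generator(device=logits.device)
#   tokens, probs = sample(("top_k_top_p", 50, 0.9, 0.8), logits, generator=gen)
#   tokens, _ = sample(GREEDY, logits, return_probs=False)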
GenericStrategyKeyType = TypeVar("GenericStrategyKeyType")
class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
@staticmethod
@abc.abstractmethod
def strategy_grouping_key(strategy: Strategy) -> GenericStrategyKeyType:
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def sample_grouped_strategies(
group_key: GenericStrategyKeyType,
strategies: list[Strategy],
logits: torch.Tensor,
*,
group_logit_indices: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
return_probs: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
raise NotImplementedError
class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
STRATEGY_KEY_TYPE: TypeAlias = Strategy
@override
@staticmethod
def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
return strategy
@override
@staticmethod
def sample_grouped_strategies(
group_key: STRATEGY_KEY_TYPE,
strategies: list[Strategy],
logits: torch.Tensor,
*,
group_logit_indices: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
return_probs: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if group_logit_indices is None:
assert logits.size(0) == len(strategies)
else:
logits = logits[group_logit_indices]
assert all(strategy == group_key for strategy in strategies), "group must be consistent"
return sample(
group_key,
logits,
generator=generator,
return_probs=return_probs,
)
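# Illustrative grouping (assumed, not part of the original change): requests
# with identical Strategy tuples share one batched sampling call, e.g. with
#   strategies = [("top_k", 50, 0.8), GREEDY, ("top_k", 50, 0.8)]
# rows 0 and 2 map to the same grouping key and can be sampled together:
#   tokens, probs = SimpleGroupedStrategySampler.sample_grouped_strategies(
#       ("top_k", 50, 0.8),
#       [("top_k", 50, 0.8)] * 2,
#       logits,
#       group_logit_indices=torch.tensor([0, 2], device=logits.device),
#       return_probs=True,
#   )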
# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
# suggestion to consider torch.nested.
def torch_multi_arange(
ends: torch.Tensor,
*,
starts: Optional[torch.Tensor] = None,
steps: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).
Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
"""
if steps is not None:
assert ends.dtype == steps.dtype
assert ends.shape == steps.shape
if starts is not None:
assert ends.dtype == starts.dtype
assert ends.shape == starts.shape
# This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
# construct the result.
#
# 1. Given N ranges (characterized by starts, ends, steps), construct a sequence
# of 2N numbers, in which the non-overlapping pairs of consecutive numbers
# correspond to the ranges. For a given range, the pair (a, b) is chosen such
# that upon torch.cumsum() application 'a' turns the last element of the
# preceding range into the start element for the current range and 'b' is
# simply the step size for the current range.
#
repeats = ends # number of elements in each range
if starts is not None:
repeats = repeats.clone()
repeats -= starts
if steps is not None:
repeats = (repeats + steps - 1).div(steps, rounding_mode="floor")
repeats = repeats.clip(0) # ignore invalid ranges
range_ends = repeats - 1 # last element in each range
if steps is not None:
range_ends *= steps
if starts is not None:
range_ends += starts
prev_range_ends = range_ends.roll(1) # last element in preceding range (or 0)
prev_range_ends[0] = 0
ones = (
torch.tensor(1, dtype=ends.dtype, pin_memory=True)
.to(device=ends.device, non_blocking=True)
.broadcast_to(ends.shape)
)
if steps is None:
steps = ones
jumps = -prev_range_ends # delta from one range to the next
if starts is not None:
jumps += starts
seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
#
# 2. Construct output via torch.repeat_interleave() and torch.cumsum()
seq_repeats = torch.cat((ones.unsqueeze(-1), (repeats - 1).unsqueeze(-1)), dim=1).view(-1)
seq = seq.repeat_interleave(seq_repeats)
seq = seq.cumsum(0)
return seq
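# Worked example (illustrative, not part of the original change):
#   torch_multi_arange(torch.tensor([3, 5]))
#   == tensor([0, 1, 2, 0, 1, 2, 3, 4])
#   torch_multi_arange(torch.tensor([5, 3]), starts=torch.tensor([2, 0]))
#   == tensor([2, 3, 4, 0, 1, 2])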

View File

@@ -256,7 +256,7 @@ class MTPSampler(TorchSampler):
assert isinstance(state, SampleStateMTP)
state.sampler_event.synchronize()
new_tokens = state.host.new_tokens.tolist()
new_tokens = state.host.new_tokens
new_tokens_lens_list = state.host.new_tokens_lens.tolist()
next_draft_tokens_list = state.host.next_draft_tokens.tolist()
beam_idx = self.BEAM

View File

@@ -1,3 +1,17 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import product
from typing import Optional, cast
@@ -37,10 +51,10 @@ class TestStrategySelection:
if params.top_p == 0:
pytest.skip("top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig")
# If this xpasses, update _check_params and doc-string of SamplingParams.
@pytest.mark.xfail(reason="top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig")
def test_top_p_0_disallowed(self):
# If this xpasses, update _check_params and doc-string of SamplingParams.
params = SamplingParams(top_p=0)
pytest.xfail("top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig")
params._get_sampling_config()
def _build_mock_llm_request(self, params: SamplingParams) -> LlmRequest:

View File

@@ -33,7 +33,7 @@ def test_get_rejected_indices():
sampled_tokens.append(draft_tokens[0])
else:
sampled_tokens.append(
sample_rejected(draft_probs, target_probs, generator, 0).item())
sample_rejected(draft_probs, target_probs, generator, 0))
sampled_regular.append(
torch.multinomial(target_probs[0],
num_samples=1,