[TRTLLM-10791][feat] TorchSampler general host time optimization (#11141)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Yukun He 2026-02-13 01:05:58 +08:00 committed by GitHub
parent 4b2b1d146b
commit cb1d8d130f
3 changed files with 253 additions and 75 deletions

View File

@@ -1,14 +1,10 @@
from copy import copy, deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
import tensorrt_llm.bindings
if TYPE_CHECKING:
from tensorrt_llm._torch.pyexecutor.sampler import Strategy
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
from tensorrt_llm.bindings import executor as tllm_executor
from tensorrt_llm.executor.result import TokenLogprobs
@@ -676,8 +672,6 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
additional_outputs=additional_outputs)
self.child_requests = []
self._py_sampling_strategy: "Strategy | None" = None
self._py_embedding_bias_1d: Optional[torch.Tensor] = None
if hasattr(self, 'embedding_bias') and self.embedding_bias is not None:
# Pre-squeeze to 1D if needed (remove batch dimension)

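For orientation, a minimal sketch (hypothetical class names, not the TensorRT-LLM code) of the change this file reflects: the resolved sampling strategy is no longer cached as an attribute on each request object but in a slot-indexed container owned by the sampler, reset whenever a slot is reused.

# Sketch only: contrasts the removed per-request cache with a slot-indexed cache.
from typing import Optional

Strategy = tuple  # stand-in for the real Strategy type

class OldStyleRequest:
    """Old pattern: each request object remembers its own resolved strategy."""
    def __init__(self) -> None:
        self._py_sampling_strategy: Optional[Strategy] = None

class SlotIndexedCache:
    """New pattern: one entry per sequence slot, owned by the sampler's Store."""
    def __init__(self, max_num_sequences: int) -> None:
        self.strategies: list[Optional[Strategy]] = [None] * max_num_sequences

    def reset_slot(self, slot: int) -> None:
        # Called when a new request is assigned to the slot (see setup_sampler_step below).
        self.strategies[slot] = None
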
View File

@@ -15,7 +15,7 @@
import enum
import sys
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from collections import defaultdict
from collections.abc import Iterable
from concurrent import futures
from dataclasses import dataclass
@@ -369,10 +369,6 @@ def _get_max_beam_width(request: LlmRequest) -> int:
return max_beam_width
def _request_sampling_params_cachable(params: UtilsSamplingParams) -> bool:
return not params.use_beam_search
def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
sampling_config = request.sampling_config
temperature = _unwrap_singleton(cast(Optional[list[float]], sampling_config.temperature))
@@ -393,16 +389,13 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
# We try to cache the resolved strategy on the request object, as it's not cheap enough to
# resolve it on every iteration.
if request._py_sampling_strategy is not None:
return request._py_sampling_strategy
"""Resolve the sampling strategy for a request.
Note: Callers inside _group_requests_by_strategy_key benefit from store.strategies
caching, which ensures this function is called at most once per request per slot.
"""
params = _request_get_sampling_params(request)
sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
if _request_sampling_params_cachable(params):
request._py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
return sampling_strategy
return resolve_sampling_strategy(params, vocab_size=vocab_size)
def _group_requests_by_strategy_key(
@@ -410,63 +403,191 @@ def _group_requests_by_strategy_key(
*,
strategy_to_key: Callable[[Strategy], GenericStrategyKeyType],
pin_memory: bool = False,
store: "TorchSampler.Store",
seq_slots: torch.Tensor,
vocab_size: int,
) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue]:
# NB: Client code relies on request indices in returned torch.Tensor being ordered
# by ascending "requests" index.
RequestGroupValueBuilder = namedtuple(
"RequestGroupValueBuilder",
[
"indices",
"strategies",
"speculation_needs_probs_list",
"need_processed_logprobs_list",
"need_raw_logprobs_list",
],
"""
Optimized implementation with vectorized boolean operations and efficient grouping.
NB: Client code relies on request indices in returned torch.Tensor being sorted.
"""
# Convert to list for efficient indexing
requests_list = list(requests) if not isinstance(requests, list) else requests
n = len(requests_list)
if n == 0:
return {}
assert not seq_slots.is_cuda, "seq_slots is expected to be a host tensor"
seq_slots_list = seq_slots.tolist()
# Get strategies from cache, only recomputing for slots that need it.
# Recompute is needed for:
# - Uncached slots (strategy is None) — recorded in store.slots_needing_recompute
# - Beam search (beam_width_in changes) — kept in slots_needing_recompute permanently
# - Speculative decoding (draft_tokens can change) — checked for non-greedy slots only
# Build strategies from cache in one shot (C-level list comprehension, ~50ns/elem)
s_strategies = store.strategies
strategies = [s_strategies[slot] for slot in seq_slots_list]
# Build slot→request_index mapping for targeted access
slot_to_idx = {slot: i for i, slot in enumerate(seq_slots_list)}
active_slots = set(slot_to_idx)
# 1) Slots pre-recorded for recompute (context-phase or beam search)
recompute_batch_slots = store.slots_needing_recompute & active_slots
# 2) Non-greedy slots where draft-token status may have changed
# (For greedy: current_has_draft is always False, matching cached, so never stale)
draft_check_slots = (store.non_greedy_slots & active_slots) - recompute_batch_slots
for slot in draft_check_slots:
i = slot_to_idx[slot]
has_draft = bool(requests_list[i].py_draft_tokens)
if store.speculation_needs_probs[slot] != has_draft:
# Draft-token status changed — only update the affected flags.
# The strategy itself doesn't depend on draft tokens (only on sampling params).
store.speculation_needs_probs[slot] = has_draft
store.needs_probs[slot] = has_draft or store.need_processed[slot]
# 3) Full recompute for the pre-recorded slots.
# Every slot with a None strategy must already be in slots_needing_recompute
# (populated by setup_sampler_step when a new request arrives).
assert None not in strategies or all(
seq_slots_list[i] in recompute_batch_slots for i in range(n) if strategies[i] is None
), (
"Found slots with uncached strategies not registered in slots_needing_recompute. "
"Ensure setup_sampler_step is called before sample_async for new requests."
)
group_dict: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueBuilder] = (
defaultdict(lambda: RequestGroupValueBuilder([], [], [], [], []))
for slot in recompute_batch_slots:
i = slot_to_idx[slot]
request = requests_list[i]
has_draft_tokens = bool(request.py_draft_tokens)
strategy = _request_strategy(request, vocab_size=vocab_size)
store.strategies[slot] = strategy
strategies[i] = strategy
is_greedy = strategy == GREEDY
store.speculation_needs_probs[slot] = has_draft_tokens and not is_greedy
store.need_processed[slot] = (
request.py_logprobs_mode == LogprobMode.PROCESSED and request.return_log_probs
)
store.need_raw[slot] = (
request.py_logprobs_mode == LogprobMode.RAW and request.return_log_probs
)
store.needs_probs[slot] = store.speculation_needs_probs[slot] or store.need_processed[slot]
# Track non-greedy slots for future draft-token checks
if is_greedy:
store.non_greedy_slots.discard(slot)
else:
store.non_greedy_slots.add(slot)
# Keep beam-search slots in the recompute set (they always need it);
# remove everything else (strategy is now cached).
if not store.uses_beam_search[slot]:
store.slots_needing_recompute.discard(slot)
# Gather flags using list comprehension (faster than append in loop)
needs_probs = torch.tensor(
[store.needs_probs[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu"
)
speculation_needs_probs = torch.tensor(
[store.speculation_needs_probs[slot] for slot in seq_slots_list],
dtype=torch.bool,
device="cpu",
)
need_processed = torch.tensor(
[store.need_processed[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu"
)
need_raw = torch.tensor(
[store.need_raw[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu"
)
# Build strategy ID mapping for vectorized comparison (all on CPU).
# NB: set() does not preserve insertion order, so we use dict.fromkeys() to deduplicate while preserving order.
unique_strategies = list(dict.fromkeys(strategies))
strategy_to_id = {s: idx for idx, s in enumerate(unique_strategies)}
strategy_ids = torch.tensor(
[strategy_to_id[s] for s in strategies], dtype=torch.int32, device="cpu"
)
for req_index, req in enumerate(requests):
strategy = _request_strategy(req, vocab_size=vocab_size)
speculation_needs_probs = (
# NB: This criterion needs to be consistent with the gating of rejection sampling in
# process_draft_tokens.
TorchSampler._speculation_could_use_rejection_sampling(req, strategy)
# Pre-allocate group_ids array
group_ids = torch.empty(n, dtype=torch.int32, device="cpu")
_next_gid = 0
def _provision_gid() -> int:
nonlocal _next_gid
gid = _next_gid
_next_gid += 1
return gid
unique_keys: defaultdict[tuple, int] = defaultdict(_provision_gid)
# Vectorized assignment: loop over unique combinations instead of all requests
for sid, strategy in enumerate(unique_strategies):
strat_mask = strategy_ids == sid
for needs_probs_val in (False, True):
# Vectorized mask for this (strategy, needs_probs) group
mask = strat_mask & (needs_probs if needs_probs_val else ~needs_probs)
if torch.any(mask):
strategy_key = strategy_to_key(strategy) # Called once per group!
key = (strategy_key, needs_probs_val)
group_ids[mask] = unique_keys[key] # Vectorized assignment
# Efficient grouping using sort
sorted_group_ids, sorted_order = torch.sort(group_ids, stable=True)
# Use prepend to detect a "change" at position 0, giving us group_starts directly
group_starts = torch.nonzero(
torch.diff(sorted_group_ids, prepend=torch.tensor([-1], device="cpu")) != 0
).squeeze(1)
group_ends = torch.cat([group_starts[1:], torch.tensor([n], device="cpu")])
# Since groups are assigned in request order, gid → key is just list indexing
id_to_key = list(unique_keys)
# Build result dictionary efficiently
result: dict[RequestGroupKey, RequestGroupValue] = {}
for gid, (start, end) in enumerate(zip(group_starts.tolist(), group_ends.tolist())):
group_sorted_indices = sorted_order[start:end]
strategy_key, needs_probs_bool = id_to_key[gid]
indices_arr = group_sorted_indices.to(torch.int32)
# Convert to list for Python list indexing
group_sorted_indices_list = group_sorted_indices.tolist()
group_strategies = [strategies[i] for i in group_sorted_indices_list]
spec_mask = speculation_needs_probs[group_sorted_indices]
spec_indices = indices_arr[spec_mask]
processed_flags = need_processed[group_sorted_indices]
raw_flags = need_raw[group_sorted_indices]
if pin_memory:
indices_tensor = indices_arr.pin_memory()
spec_tensor = spec_indices.pin_memory()
processed_tensor = processed_flags.pin_memory()
raw_tensor = raw_flags.pin_memory()
else:
indices_tensor = indices_arr
spec_tensor = spec_indices
processed_tensor = processed_flags
raw_tensor = raw_flags
result[RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs_bool)] = (
RequestGroupValue(
indices=indices_tensor,
strategies=group_strategies,
speculation_needs_probs_indices=spec_tensor,
need_processed_logprobs=processed_tensor,
need_raw_logprobs=raw_tensor,
)
)
need_processed_logprobs = (
req.py_logprobs_mode == LogprobMode.PROCESSED and req.return_log_probs
)
need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs
needs_probs = speculation_needs_probs or need_processed_logprobs
strategy_key = strategy_to_key(strategy)
group_dict_entry = group_dict[
RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs)
]
group_dict_entry.indices.append(req_index)
group_dict_entry.strategies.append(strategy)
if speculation_needs_probs:
group_dict_entry.speculation_needs_probs_list.append(req_index)
group_dict_entry.need_processed_logprobs_list.append(need_processed_logprobs)
group_dict_entry.need_raw_logprobs_list.append(need_raw_logprobs)
return {
group_key: RequestGroupValue(
indices=torch.tensor(group_value.indices, pin_memory=pin_memory, dtype=torch.int32),
strategies=group_value.strategies,
speculation_needs_probs_indices=torch.tensor(
group_value.speculation_needs_probs_list, pin_memory=pin_memory, dtype=torch.int32
),
need_processed_logprobs=torch.tensor(
group_value.need_processed_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
),
need_raw_logprobs=torch.tensor(
group_value.need_raw_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
),
)
for group_key, group_value in group_dict.items()
}
return result
def add_token(
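The grouping rewrite above replaces a per-request Python loop with a sort-based pass. Below is a standalone sketch (hypothetical data, not the library code) of the same torch.sort / torch.diff trick: stable-sort the per-request group ids, then detect run boundaries to recover each group's member indices. Because the sort is stable, indices inside each group stay in ascending request order, the invariant the NB comment above relies on.

# Sketch of the sort-based grouping used in _group_requests_by_strategy_key.
import torch

group_ids = torch.tensor([2, 0, 1, 0, 2, 1, 0])  # one group id per request in the batch
n = group_ids.numel()

sorted_ids, sorted_order = torch.sort(group_ids, stable=True)
# Prepending -1 makes position 0 count as a "change", so group starts fall out directly.
starts = torch.nonzero(torch.diff(sorted_ids, prepend=torch.tensor([-1])) != 0).squeeze(1)
ends = torch.cat([starts[1:], torch.tensor([n])])

for start, end in zip(starts.tolist(), ends.tolist()):
    gid = sorted_ids[start].item()
    members = sorted_order[start:end].tolist()   # ascending request indices for this group
    print(f"group {gid}: request indices {members}")
# group 0: request indices [1, 3, 6]
# group 1: request indices [2, 5]
# group 2: request indices [0, 4]
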
@@ -962,6 +1083,22 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
"""Shape: batch_size, beam_width, sequence_length
Usage: Stores the original tokens for each beam.
This is used to recover the original tokens for each beam when streaming is enabled"""
speculation_needs_probs: list | None = None
"""Length: max_num_sequences. True if request has draft tokens and non-greedy sampling."""
need_processed: list | None = None
"""Length: max_num_sequences. True if logprob mode is PROCESSED and return_log_probs is set."""
need_raw: list | None = None
"""Length: max_num_sequences. True if logprob mode is RAW and return_log_probs is set."""
needs_probs: list | None = None
"""Length: max_num_sequences. True if speculation_needs_probs or need_processed."""
strategies: list | None = None
"""Length: max_num_sequences. Stores cached Strategy tuple for each seq_slot."""
uses_beam_search: list | None = None
"""Length: max_num_sequences. True if max_beam_width > 1 for this slot."""
slots_needing_recompute: set | None = None
"""Slots where strategy needs (re)computation. Populated in setup_sampler_step."""
non_greedy_slots: set | None = None
"""Slots with non-greedy strategies. Used to limit draft-token checks."""
def __post_init__(self):
assert self.new_tokens.shape == self.finish_reasons.shape
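The flag fields added to Store above are plain Python lists rather than tensors. Here is a small sketch (hypothetical data, not the library code) of how such a per-slot list is gathered into a per-batch CPU tensor at sampling time; plain list indexing avoids the per-element .item() calls a tensor-backed store would incur in the host-side hot loop.

# Sketch: per-slot flag list -> per-batch boolean tensor, as done when grouping requests.
import torch

max_num_sequences = 8
need_raw = [False] * max_num_sequences  # one entry per seq slot, like Store.need_raw
need_raw[3] = True
need_raw[5] = True

seq_slots = [5, 1, 3]  # slots scheduled in this iteration (host-side list)
need_raw_batch = torch.tensor(
    [need_raw[slot] for slot in seq_slots], dtype=torch.bool, device="cpu"
)
print(need_raw_batch)  # tensor([ True, False,  True])
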
@@ -991,6 +1128,17 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
predecessor_beams: torch.Tensor | None = None
original_tokens: torch.Tensor | None = None
first_finish_reasons: torch.Tensor | None = None
# Use Python lists instead of tensors to avoid .item() overhead in hot loops
speculation_needs_probs: list = [False] * self.max_num_sequences
need_processed: list = [False] * self.max_num_sequences
need_raw: list = [False] * self.max_num_sequences
needs_probs: list = [False] * self.max_num_sequences
strategies: list = [None] * self.max_num_sequences
uses_beam_search: list = [False] * self.max_num_sequences
slots_needing_recompute: set = set()
non_greedy_slots: set = set()
if self._use_beam_search:
cache_indirection = torch.empty(
self.CACHE_INDIRECTION_SHAPE, device="cuda", dtype=torch.int
@@ -1018,6 +1166,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
predecessor_beams=predecessor_beams,
original_tokens=original_tokens,
first_finish_reasons=first_finish_reasons,
speculation_needs_probs=speculation_needs_probs,
need_processed=need_processed,
need_raw=need_raw,
needs_probs=needs_probs,
strategies=strategies,
uses_beam_search=uses_beam_search,
slots_needing_recompute=slots_needing_recompute,
non_greedy_slots=non_greedy_slots,
)
@dataclass(frozen=True, kw_only=True)
@@ -1528,7 +1684,8 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
for request in scheduled_requests.context_requests:
if self._is_new_request(request):
assert request.py_seq_slot is not None
seq_slots.append(request.py_seq_slot)
slot = request.py_seq_slot
seq_slots.append(slot)
max_lens.append(
min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
)
@@ -1539,6 +1696,13 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
raise ValueError("Beam search does not support multiple logprobs")
prompt_lens.append(request.py_prompt_len)
# Initialize cached data for this slot (prevents stale data from previous request)
self.store.strategies[slot] = None
self.store.uses_beam_search[slot] = _get_max_beam_width(request) > 1
# Mark slot for strategy recomputation in _group_requests_by_strategy_key
self.store.slots_needing_recompute.add(slot)
self.store.non_greedy_slots.discard(slot) # reset until strategy is computed
if len(seq_slots) > 0:
full_list = [seq_slots, max_lens, end_ids]
if self._use_beam_search:
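The slot lifecycle wired up here can be summarized with a small sketch (hypothetical stand-ins, not the library code): setup_sampler_step resets the slot's cached data and registers it for recomputation, the next grouping pass resolves the strategy once and caches it, and only beam-search slots remain permanently registered.

# Sketch of the invalidate -> recompute -> cache flow between setup_sampler_step
# and _group_requests_by_strategy_key.
strategies = [None] * 4                  # Store.strategies, one entry per seq slot
slots_needing_recompute = set()          # Store.slots_needing_recompute
uses_beam_search = [False] * 4           # Store.uses_beam_search

def on_new_request(slot: int, beam_width: int) -> None:
    # Mirrors setup_sampler_step: reset cached data, mark the slot for recompute.
    strategies[slot] = None
    uses_beam_search[slot] = beam_width > 1
    slots_needing_recompute.add(slot)

def resolve_for_slot(slot: int) -> tuple:
    # Mirrors the recompute branch in _group_requests_by_strategy_key.
    if slot in slots_needing_recompute:
        strategies[slot] = ("top_p", 0.9)          # stand-in for resolve_sampling_strategy
        if not uses_beam_search[slot]:
            slots_needing_recompute.discard(slot)  # beam-search slots stay registered
    return strategies[slot]

on_new_request(slot=2, beam_width=1)
assert resolve_for_slot(2) == ("top_p", 0.9)       # computed on the first iteration
assert 2 not in slots_needing_recompute            # subsequent iterations use the cache
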
@@ -2450,8 +2614,10 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
grouped_requests = _group_requests_by_strategy_key(
requests,
pin_memory=True,
vocab_size=logits_cuda.size(1),
strategy_to_key=self._grouped_sampler_cls.strategy_grouping_key,
store=self.store,
seq_slots=seq_slots,
vocab_size=logits_cuda.size(1), # Dummy value; strategy should already be cached
)
grouped_requests_with_metadata = self._add_metadata_to_grouped_requests(
requests,

View File

@@ -84,9 +84,6 @@ class TestStrategySelection:
sampling_config: SamplingConfig
is_context_init_state: bool # Torch sampler accesses this, but it does not affect this test
def __init__(self):
self._py_sampling_strategy: Strategy | None = None
def get_beam_width_by_iter(
self, for_next_iteration: bool = False
) -> int: # Torch sampler accesses this, but it does not affect this test
@@ -1227,6 +1224,24 @@ class TestBatchedSampling:
)
)
@staticmethod
def _init_store_for_new_requests(
sampler: TorchSampler,
scheduled_requests: ScheduledRequests,
) -> None:
"""Initialize store for request slots that haven't been through setup_sampler_step.
In production, setup_sampler_step registers each new request's slot in
store.slots_needing_recompute so that _group_requests_by_strategy_key
knows to compute its strategy. Tests skip setup_sampler_step, so we
replicate the relevant initialization here to exercise the same code
path as production.
"""
for req in scheduled_requests.all_requests():
slot = req.py_seq_slot
if sampler.store.strategies[slot] is None:
sampler.store.slots_needing_recompute.add(slot)
def _sample(
self,
sampler: TorchSampler,
@@ -1241,6 +1256,9 @@ class TestBatchedSampling:
Optionally, run sampling repeatedly, e.g., to gather statistics.
"""
assert not scheduled_requests.context_requests
# Simulate the store initialization that setup_sampler_step performs
# for new requests in production.
self._init_store_for_new_requests(sampler, scheduled_requests)
num_actual_repeats = num_repeats if num_repeats is not None else 1
T = TypeVar("T")