From cb1d8d130f784c7181633b57f5be9b349610c429 Mon Sep 17 00:00:00 2001 From: Yukun He <23156053+hyukn@users.noreply.github.com> Date: Fri, 13 Feb 2026 01:05:58 +0800 Subject: [PATCH] [TRTLLM-10791][feat] TorchSampler general host time optimization (#11141) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 8 +- tensorrt_llm/_torch/pyexecutor/sampler.py | 296 ++++++++++++++---- .../_torch/sampler/test_torch_sampler.py | 24 +- 3 files changed, 253 insertions(+), 75 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 092fde5a48..da39ea59fd 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1,14 +1,10 @@ from copy import copy, deepcopy from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch import tensorrt_llm.bindings - -if TYPE_CHECKING: - from tensorrt_llm._torch.pyexecutor.sampler import Strategy - from tensorrt_llm._torch.shared_tensor import SharedTensorContainer from tensorrt_llm.bindings import executor as tllm_executor from tensorrt_llm.executor.result import TokenLogprobs @@ -676,8 +672,6 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): additional_outputs=additional_outputs) self.child_requests = [] - self._py_sampling_strategy: "Strategy | None" = None - self._py_embedding_bias_1d: Optional[torch.Tensor] = None if hasattr(self, 'embedding_bias') and self.embedding_bias is not None: # Pre-squeeze to 1D if needed (remove batch dimension) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index db6210da8d..b6841f0865 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -15,7 +15,7 @@ import enum import sys from abc import ABC, abstractmethod -from collections import defaultdict, namedtuple +from collections import defaultdict from collections.abc import Iterable from concurrent import futures from dataclasses import dataclass @@ -369,10 +369,6 @@ def _get_max_beam_width(request: LlmRequest) -> int: return max_beam_width -def _request_sampling_params_cachable(params: UtilsSamplingParams) -> bool: - return not params.use_beam_search - - def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams: sampling_config = request.sampling_config temperature = _unwrap_singleton(cast(Optional[list[float]], sampling_config.temperature)) @@ -393,16 +389,13 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams: def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy: - # We try to cache the resolved strategy on the request object, as it's not cheap enough to - # resolve it on every iteration. - if request._py_sampling_strategy is not None: - return request._py_sampling_strategy + """Resolve the sampling strategy for a request. + Note: Callers inside _group_requests_by_strategy_key benefit from store.strategies + caching, which ensures this function is called at most once per request per slot. 
+ """ params = _request_get_sampling_params(request) - sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) - if _request_sampling_params_cachable(params): - request._py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) - return sampling_strategy + return resolve_sampling_strategy(params, vocab_size=vocab_size) def _group_requests_by_strategy_key( @@ -410,63 +403,191 @@ def _group_requests_by_strategy_key( *, strategy_to_key: Callable[[Strategy], GenericStrategyKeyType], pin_memory: bool = False, + store: "TorchSampler.Store", + seq_slots: torch.Tensor, vocab_size: int, ) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue]: - # NB: Client code relies on request indices in returned torch.Tensor being ordered - # by ascending "requests" index. - RequestGroupValueBuilder = namedtuple( - "RequestGroupValueBuilder", - [ - "indices", - "strategies", - "speculation_needs_probs_list", - "need_processed_logprobs_list", - "need_raw_logprobs_list", - ], + """ + Optimized implementation with vectorized boolean operations and efficient grouping. + + NB: Client code relies on request indices in returned torch.Tensor being sorted. + """ + # Convert to list for efficient indexing + requests_list = list(requests) if not isinstance(requests, list) else requests + n = len(requests_list) + + if n == 0: + return {} + + assert not seq_slots.is_cuda, "seq_slots is expected to be a host tensor" + seq_slots_list = seq_slots.tolist() + + # Get strategies from cache, only recomputing for slots that need it. + # Recompute is needed for: + # - Uncached slots (strategy is None) — recorded in store.slots_needing_recompute + # - Beam search (beam_width_in changes) — kept in slots_needing_recompute permanently + # - Speculative decoding (draft_tokens can change) — checked for non-greedy slots only + + # Build strategies from cache in one shot (C-level list comprehension, ~50ns/elem) + s_strategies = store.strategies + strategies = [s_strategies[slot] for slot in seq_slots_list] + + # Build slot→request_index mapping for targeted access + slot_to_idx = {slot: i for i, slot in enumerate(seq_slots_list)} + active_slots = set(slot_to_idx) + + # 1) Slots pre-recorded for recompute (context-phase or beam search) + recompute_batch_slots = store.slots_needing_recompute & active_slots + + # 2) Non-greedy slots where draft-token status may have changed + # (For greedy: current_has_draft is always False, matching cached, so never stale) + draft_check_slots = (store.non_greedy_slots & active_slots) - recompute_batch_slots + for slot in draft_check_slots: + i = slot_to_idx[slot] + has_draft = bool(requests_list[i].py_draft_tokens) + if store.speculation_needs_probs[slot] != has_draft: + # Draft-token status changed — only update the affected flags. + # The strategy itself doesn't depend on draft tokens (only on sampling params). + store.speculation_needs_probs[slot] = has_draft + store.needs_probs[slot] = has_draft or store.need_processed[slot] + + # 3) Full recompute for the pre-recorded slots. + # Every slot with a None strategy must already be in slots_needing_recompute + # (populated by setup_sampler_step when a new request arrives). + assert None not in strategies or all( + seq_slots_list[i] in recompute_batch_slots for i in range(n) if strategies[i] is None + ), ( + "Found slots with uncached strategies not registered in slots_needing_recompute. " + "Ensure setup_sampler_step is called before sample_async for new requests." 
) - group_dict: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueBuilder] = ( - defaultdict(lambda: RequestGroupValueBuilder([], [], [], [], [])) + for slot in recompute_batch_slots: + i = slot_to_idx[slot] + request = requests_list[i] + has_draft_tokens = bool(request.py_draft_tokens) + + strategy = _request_strategy(request, vocab_size=vocab_size) + store.strategies[slot] = strategy + strategies[i] = strategy + + is_greedy = strategy == GREEDY + store.speculation_needs_probs[slot] = has_draft_tokens and not is_greedy + store.need_processed[slot] = ( + request.py_logprobs_mode == LogprobMode.PROCESSED and request.return_log_probs + ) + store.need_raw[slot] = ( + request.py_logprobs_mode == LogprobMode.RAW and request.return_log_probs + ) + store.needs_probs[slot] = store.speculation_needs_probs[slot] or store.need_processed[slot] + + # Track non-greedy slots for future draft-token checks + if is_greedy: + store.non_greedy_slots.discard(slot) + else: + store.non_greedy_slots.add(slot) + + # Keep beam-search slots in the recompute set (they always need it); + # remove everything else (strategy is now cached). + if not store.uses_beam_search[slot]: + store.slots_needing_recompute.discard(slot) + + # Gather flags using list comprehension (faster than append in loop) + needs_probs = torch.tensor( + [store.needs_probs[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu" + ) + speculation_needs_probs = torch.tensor( + [store.speculation_needs_probs[slot] for slot in seq_slots_list], + dtype=torch.bool, + device="cpu", + ) + need_processed = torch.tensor( + [store.need_processed[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu" + ) + need_raw = torch.tensor( + [store.need_raw[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu" + ) + # Build strategy ID mapping for vectorized comparison (all on CPU). + # NB: set() does not preserve insertion order, so we use dict.fromkeys() to deduplicate while preserving order. + unique_strategies = list(dict.fromkeys(strategies)) + strategy_to_id = {s: idx for idx, s in enumerate(unique_strategies)} + strategy_ids = torch.tensor( + [strategy_to_id[s] for s in strategies], dtype=torch.int32, device="cpu" ) - for req_index, req in enumerate(requests): - strategy = _request_strategy(req, vocab_size=vocab_size) - speculation_needs_probs = ( - # NB: This criterion needs to be consistent with the gating of rejection sampling in - # process_draft_tokens. - TorchSampler._speculation_could_use_rejection_sampling(req, strategy) + # Pre-allocate group_ids array + group_ids = torch.empty(n, dtype=torch.int32, device="cpu") + + _next_gid = 0 + + def _provision_gid() -> int: + nonlocal _next_gid + gid = _next_gid + _next_gid += 1 + return gid + + unique_keys: defaultdict[tuple, int] = defaultdict(_provision_gid) + + # Vectorized assignment: loop over unique combinations instead of all requests + for sid, strategy in enumerate(unique_strategies): + strat_mask = strategy_ids == sid + + for needs_probs_val in (False, True): + # Vectorized mask for this (strategy, needs_probs) group + mask = strat_mask & (needs_probs if needs_probs_val else ~needs_probs) + + if torch.any(mask): + strategy_key = strategy_to_key(strategy) # Called once per group! 
+ key = (strategy_key, needs_probs_val) + group_ids[mask] = unique_keys[key] # Vectorized assignment + + # Efficient grouping using sort + sorted_group_ids, sorted_order = torch.sort(group_ids, stable=True) + # Use prepend to detect a "change" at position 0, giving us group_starts directly + group_starts = torch.nonzero( + torch.diff(sorted_group_ids, prepend=torch.tensor([-1], device="cpu")) != 0 + ).squeeze(1) + group_ends = torch.cat([group_starts[1:], torch.tensor([n], device="cpu")]) + # Since groups are assigned in request order, gid → key is just list indexing + id_to_key = list(unique_keys) + + # Build result dictionary efficiently + result: dict[RequestGroupKey, RequestGroupValue] = {} + + for gid, (start, end) in enumerate(zip(group_starts.tolist(), group_ends.tolist())): + group_sorted_indices = sorted_order[start:end] + strategy_key, needs_probs_bool = id_to_key[gid] + + indices_arr = group_sorted_indices.to(torch.int32) + # Convert to list for Python list indexing + group_sorted_indices_list = group_sorted_indices.tolist() + group_strategies = [strategies[i] for i in group_sorted_indices_list] + spec_mask = speculation_needs_probs[group_sorted_indices] + spec_indices = indices_arr[spec_mask] + processed_flags = need_processed[group_sorted_indices] + raw_flags = need_raw[group_sorted_indices] + + if pin_memory: + indices_tensor = indices_arr.pin_memory() + spec_tensor = spec_indices.pin_memory() + processed_tensor = processed_flags.pin_memory() + raw_tensor = raw_flags.pin_memory() + else: + indices_tensor = indices_arr + spec_tensor = spec_indices + processed_tensor = processed_flags + raw_tensor = raw_flags + + result[RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs_bool)] = ( + RequestGroupValue( + indices=indices_tensor, + strategies=group_strategies, + speculation_needs_probs_indices=spec_tensor, + need_processed_logprobs=processed_tensor, + need_raw_logprobs=raw_tensor, + ) ) - need_processed_logprobs = ( - req.py_logprobs_mode == LogprobMode.PROCESSED and req.return_log_probs - ) - need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs - needs_probs = speculation_needs_probs or need_processed_logprobs - strategy_key = strategy_to_key(strategy) - group_dict_entry = group_dict[ - RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs) - ] - group_dict_entry.indices.append(req_index) - group_dict_entry.strategies.append(strategy) - if speculation_needs_probs: - group_dict_entry.speculation_needs_probs_list.append(req_index) - group_dict_entry.need_processed_logprobs_list.append(need_processed_logprobs) - group_dict_entry.need_raw_logprobs_list.append(need_raw_logprobs) - return { - group_key: RequestGroupValue( - indices=torch.tensor(group_value.indices, pin_memory=pin_memory, dtype=torch.int32), - strategies=group_value.strategies, - speculation_needs_probs_indices=torch.tensor( - group_value.speculation_needs_probs_list, pin_memory=pin_memory, dtype=torch.int32 - ), - need_processed_logprobs=torch.tensor( - group_value.need_processed_logprobs_list, pin_memory=pin_memory, dtype=torch.bool - ), - need_raw_logprobs=torch.tensor( - group_value.need_raw_logprobs_list, pin_memory=pin_memory, dtype=torch.bool - ), - ) - for group_key, group_value in group_dict.items() - } + + return result def add_token( @@ -962,6 +1083,22 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin): """Shape: batch_size, beam_width, sequence_length Usage: Stores the original tokens for each beam. 
This is used to recover the original tokens for each beam when streaming is enabled""" + speculation_needs_probs: list | None = None + """Length: max_num_sequences. True if request has draft tokens and non-greedy sampling.""" + need_processed: list | None = None + """Length: max_num_sequences. True if logprob mode is PROCESSED and return_log_probs is set.""" + need_raw: list | None = None + """Length: max_num_sequences. True if logprob mode is RAW and return_log_probs is set.""" + needs_probs: list | None = None + """Length: max_num_sequences. True if speculation_needs_probs or need_processed.""" + strategies: list | None = None + """Length: max_num_sequences. Stores cached Strategy tuple for each seq_slot.""" + uses_beam_search: list | None = None + """Length: max_num_sequences. True if max_beam_width > 1 for this slot.""" + slots_needing_recompute: set | None = None + """Slots where strategy needs (re)computation. Populated in setup_sampler_step.""" + non_greedy_slots: set | None = None + """Slots with non-greedy strategies. Used to limit draft-token checks.""" def __post_init__(self): assert self.new_tokens.shape == self.finish_reasons.shape @@ -991,6 +1128,17 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin): predecessor_beams: torch.Tensor | None = None original_tokens: torch.Tensor | None = None first_finish_reasons: torch.Tensor | None = None + + # Use Python lists instead of tensors to avoid .item() overhead in hot loops + speculation_needs_probs: list = [False] * self.max_num_sequences + need_processed: list = [False] * self.max_num_sequences + need_raw: list = [False] * self.max_num_sequences + needs_probs: list = [False] * self.max_num_sequences + strategies: list = [None] * self.max_num_sequences + uses_beam_search: list = [False] * self.max_num_sequences + slots_needing_recompute: set = set() + non_greedy_slots: set = set() + if self._use_beam_search: cache_indirection = torch.empty( self.CACHE_INDIRECTION_SHAPE, device="cuda", dtype=torch.int @@ -1018,6 +1166,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin): predecessor_beams=predecessor_beams, original_tokens=original_tokens, first_finish_reasons=first_finish_reasons, + speculation_needs_probs=speculation_needs_probs, + need_processed=need_processed, + need_raw=need_raw, + needs_probs=needs_probs, + strategies=strategies, + uses_beam_search=uses_beam_search, + slots_needing_recompute=slots_needing_recompute, + non_greedy_slots=non_greedy_slots, ) @dataclass(frozen=True, kw_only=True) @@ -1528,7 +1684,8 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin): for request in scheduled_requests.context_requests: if self._is_new_request(request): assert request.py_seq_slot is not None - seq_slots.append(request.py_seq_slot) + slot = request.py_seq_slot + seq_slots.append(slot) max_lens.append( min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens) ) @@ -1539,6 +1696,13 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin): raise ValueError("Beam search does not support multiple logprobs") prompt_lens.append(request.py_prompt_len) + # Initialize cached data for this slot (prevents stale data from previous request) + self.store.strategies[slot] = None + self.store.uses_beam_search[slot] = _get_max_beam_width(request) > 1 + # Mark slot for strategy recomputation in _group_requests_by_strategy_key + self.store.slots_needing_recompute.add(slot) + self.store.non_greedy_slots.discard(slot) # reset until strategy is computed + if len(seq_slots) > 0: 
full_list = [seq_slots, max_lens, end_ids] if self._use_beam_search: @@ -2450,8 +2614,10 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin): grouped_requests = _group_requests_by_strategy_key( requests, pin_memory=True, - vocab_size=logits_cuda.size(1), strategy_to_key=self._grouped_sampler_cls.strategy_grouping_key, + store=self.store, + seq_slots=seq_slots, + vocab_size=logits_cuda.size(1), # Dummy value; strategy should already be cached ) grouped_requests_with_metadata = self._add_metadata_to_grouped_requests( requests, diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py index 3557b5c038..5aa6a92eda 100644 --- a/tests/unittest/_torch/sampler/test_torch_sampler.py +++ b/tests/unittest/_torch/sampler/test_torch_sampler.py @@ -84,9 +84,6 @@ class TestStrategySelection: sampling_config: SamplingConfig is_context_init_state: bool # Torch sampler accesses this, but it does not affect this test - def __init__(self): - self._py_sampling_strategy: Strategy | None = None - def get_beam_width_by_iter( self, for_next_iteration: bool = False ) -> int: # Torch sampler accesses this, but it does not affect this test @@ -1227,6 +1224,24 @@ class TestBatchedSampling: ) ) + @staticmethod + def _init_store_for_new_requests( + sampler: TorchSampler, + scheduled_requests: ScheduledRequests, + ) -> None: + """Initialize store for request slots that haven't been through setup_sampler_step. + + In production, setup_sampler_step registers each new request's slot in + store.slots_needing_recompute so that _group_requests_by_strategy_key + knows to compute its strategy. Tests skip setup_sampler_step, so we + replicate the relevant initialization here to exercise the same code + path as production. + """ + for req in scheduled_requests.all_requests(): + slot = req.py_seq_slot + if sampler.store.strategies[slot] is None: + sampler.store.slots_needing_recompute.add(slot) + def _sample( self, sampler: TorchSampler, @@ -1241,6 +1256,9 @@ class TestBatchedSampling: Optionally, run sampling repeatedly, e.g., to gather statistics. """ assert not scheduled_requests.context_requests + # Simulate the store initialization that setup_sampler_step performs + # for new requests in production. + self._init_store_for_new_requests(sampler, scheduled_requests) num_actual_repeats = num_repeats if num_repeats is not None else 1 T = TypeVar("T")
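For reviewers who want to see the sort-based grouping used in _group_requests_by_strategy_key in isolation, the following standalone sketch (toy, hypothetical group ids; not part of the patch) shows how a stable torch.sort plus torch.diff with prepend recovers contiguous per-group slices while keeping request indices in ascending order within each group, which is the ordering the client code relies on.

import torch

# Toy input: one (strategy, needs_probs) group id per request, in request order.
group_ids = torch.tensor([0, 1, 0, 2, 1, 0, 2, 2], dtype=torch.int32)
n = group_ids.numel()

# Stable sort keeps ascending request order within each group.
sorted_group_ids, sorted_order = torch.sort(group_ids, stable=True)

# diff(..., prepend=[-1]) is nonzero wherever the group id changes, including
# at position 0, so the nonzero positions are exactly the group start offsets.
group_starts = torch.nonzero(
    torch.diff(sorted_group_ids, prepend=torch.tensor([-1], dtype=torch.int32)) != 0
).squeeze(1)
group_ends = torch.cat([group_starts[1:], torch.tensor([n])])

for start, end in zip(group_starts.tolist(), group_ends.tolist()):
    gid = int(sorted_group_ids[start])
    indices = sorted_order[start:end]  # request indices belonging to this group
    print(gid, indices.tolist())
# Prints:
# 0 [0, 2, 5]
# 1 [1, 4]
# 2 [3, 6, 7]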
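The per-slot caching protocol between setup_sampler_step and _group_requests_by_strategy_key can also be summarized without the TorchSampler machinery. The sketch below uses simplified stand-ins (SlotStore, on_new_request, strategy_for_step, and resolve are illustrative names, not APIs from this patch) and assumes the same contract: a new request invalidates its slot and flags it for recomputation, the first sampling step resolves and caches the strategy, later steps hit the cache, and beam-search slots stay flagged because their strategy can change between steps.

from dataclasses import dataclass, field

@dataclass
class SlotStore:
    strategies: list                 # cached strategy per seq_slot (None = not cached)
    uses_beam_search: list           # True -> never cacheable, always recompute
    slots_needing_recompute: set = field(default_factory=set)

def on_new_request(store: SlotStore, slot: int, beam_width: int) -> None:
    # Mirrors what setup_sampler_step does when a request claims a slot:
    # drop any stale cache entry and flag the slot for recomputation.
    store.strategies[slot] = None
    store.uses_beam_search[slot] = beam_width > 1
    store.slots_needing_recompute.add(slot)

def strategy_for_step(store: SlotStore, slot: int, resolve) -> object:
    # Mirrors the fast path in the grouping step: reuse the cached strategy
    # unless the slot is flagged for recomputation.
    if slot in store.slots_needing_recompute:
        store.strategies[slot] = resolve()
        if not store.uses_beam_search[slot]:
            store.slots_needing_recompute.discard(slot)  # cache is now valid
    return store.strategies[slot]

# Usage: slot 3 resolves its strategy once, then hits the cache on later steps.
store = SlotStore(strategies=[None] * 8, uses_beam_search=[False] * 8)
on_new_request(store, slot=3, beam_width=1)

calls = 0
def resolve():
    global calls
    calls += 1
    return ("top_p", 0.9)  # hypothetical resolved strategy tuple

for _ in range(4):
    strategy_for_step(store, slot=3, resolve=resolve)
assert calls == 1  # strategy resolved once, cached afterwards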