[TRTLLM-10791][feat] TorchSampler general host time optimization (#11141)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
parent 4b2b1d146b
commit cb1d8d130f
@@ -1,14 +1,10 @@
 from copy import copy, deepcopy
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
 import tensorrt_llm.bindings
 
-if TYPE_CHECKING:
-    from tensorrt_llm._torch.pyexecutor.sampler import Strategy
-
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.bindings import executor as tllm_executor
 from tensorrt_llm.executor.result import TokenLogprobs
@@ -676,8 +672,6 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
             additional_outputs=additional_outputs)
         self.child_requests = []
 
-        self._py_sampling_strategy: "Strategy | None" = None
-
         self._py_embedding_bias_1d: Optional[torch.Tensor] = None
         if hasattr(self, 'embedding_bias') and self.embedding_bias is not None:
             # Pre-squeeze to 1D if needed (remove batch dimension)
@@ -15,7 +15,7 @@
 import enum
 import sys
 from abc import ABC, abstractmethod
-from collections import defaultdict, namedtuple
+from collections import defaultdict
 from collections.abc import Iterable
 from concurrent import futures
 from dataclasses import dataclass
@@ -369,10 +369,6 @@ def _get_max_beam_width(request: LlmRequest) -> int:
     return max_beam_width
 
 
-def _request_sampling_params_cachable(params: UtilsSamplingParams) -> bool:
-    return not params.use_beam_search
-
-
 def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
     sampling_config = request.sampling_config
     temperature = _unwrap_singleton(cast(Optional[list[float]], sampling_config.temperature))
@@ -393,16 +389,13 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
 
 
 def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
-    # We try to cache the resolved strategy on the request object, as it's not cheap enough to
-    # resolve it on every iteration.
-    if request._py_sampling_strategy is not None:
-        return request._py_sampling_strategy
-
+    """Resolve the sampling strategy for a request.
+
+    Note: Callers inside _group_requests_by_strategy_key benefit from store.strategies
+    caching, which ensures this function is called at most once per request per slot.
+    """
     params = _request_get_sampling_params(request)
-    sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
-    if _request_sampling_params_cachable(params):
-        request._py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
-    return sampling_strategy
+    return resolve_sampling_strategy(params, vocab_size=vocab_size)
 
 
 def _group_requests_by_strategy_key(
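The docstring above states the key idea of this change: a request's strategy depends only on its sampling parameters, so it can be resolved once when the request is assigned a seq slot and reused on every subsequent decode step. A minimal sketch of that per-slot memoization, with hypothetical names rather than the real TorchSampler.Store API:

```python
from dataclasses import dataclass, field
from typing import Callable

Strategy = tuple  # hypothetical stand-in, e.g. ("top_p", 0.9) or ("greedy", None)


@dataclass
class SlotStrategyCache:
    """Resolve a request's strategy once per slot assignment, then reuse it."""
    max_num_sequences: int
    _strategies: list = field(init=False)

    def __post_init__(self):
        self._strategies = [None] * self.max_num_sequences

    def reset(self, slot: int) -> None:
        """Called when a new request takes over the slot (cf. setup_sampler_step)."""
        self._strategies[slot] = None

    def get(self, slot: int, resolve: Callable[[], Strategy]) -> Strategy:
        if self._strategies[slot] is None:
            self._strategies[slot] = resolve()  # the expensive part: reading sampling params
        return self._strategies[slot]


calls = 0

def resolve_once() -> Strategy:
    global calls
    calls += 1
    return ("top_p", 0.9)

cache = SlotStrategyCache(max_num_sequences=8)
cache.reset(slot=3)
for _ in range(100):                            # 100 decode steps, a single resolution
    assert cache.get(3, resolve_once) == ("top_p", 0.9)
assert calls == 1
```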
@@ -410,63 +403,191 @@ def _group_requests_by_strategy_key(
     *,
     strategy_to_key: Callable[[Strategy], GenericStrategyKeyType],
     pin_memory: bool = False,
+    store: "TorchSampler.Store",
+    seq_slots: torch.Tensor,
     vocab_size: int,
 ) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue]:
-    # NB: Client code relies on request indices in returned torch.Tensor being ordered
-    # by ascending "requests" index.
-    RequestGroupValueBuilder = namedtuple(
-        "RequestGroupValueBuilder",
-        [
-            "indices",
-            "strategies",
-            "speculation_needs_probs_list",
-            "need_processed_logprobs_list",
-            "need_raw_logprobs_list",
-        ],
-    )
+    """
+    Optimized implementation with vectorized boolean operations and efficient grouping.
+
+    NB: Client code relies on request indices in returned torch.Tensor being sorted.
+    """
+    # Convert to list for efficient indexing
+    requests_list = list(requests) if not isinstance(requests, list) else requests
+    n = len(requests_list)
+
+    if n == 0:
+        return {}
+
+    assert not seq_slots.is_cuda, "seq_slots is expected to be a host tensor"
+    seq_slots_list = seq_slots.tolist()
+
+    # Get strategies from cache, only recomputing for slots that need it.
+    # Recompute is needed for:
+    # - Uncached slots (strategy is None) — recorded in store.slots_needing_recompute
+    # - Beam search (beam_width_in changes) — kept in slots_needing_recompute permanently
+    # - Speculative decoding (draft_tokens can change) — checked for non-greedy slots only
+
+    # Build strategies from cache in one shot (C-level list comprehension, ~50ns/elem)
+    s_strategies = store.strategies
+    strategies = [s_strategies[slot] for slot in seq_slots_list]
+
+    # Build slot→request_index mapping for targeted access
+    slot_to_idx = {slot: i for i, slot in enumerate(seq_slots_list)}
+    active_slots = set(slot_to_idx)
+
+    # 1) Slots pre-recorded for recompute (context-phase or beam search)
+    recompute_batch_slots = store.slots_needing_recompute & active_slots
+
+    # 2) Non-greedy slots where draft-token status may have changed
+    # (For greedy: current_has_draft is always False, matching cached, so never stale)
+    draft_check_slots = (store.non_greedy_slots & active_slots) - recompute_batch_slots
+    for slot in draft_check_slots:
+        i = slot_to_idx[slot]
+        has_draft = bool(requests_list[i].py_draft_tokens)
+        if store.speculation_needs_probs[slot] != has_draft:
+            # Draft-token status changed — only update the affected flags.
+            # The strategy itself doesn't depend on draft tokens (only on sampling params).
+            store.speculation_needs_probs[slot] = has_draft
+            store.needs_probs[slot] = has_draft or store.need_processed[slot]
+
+    # 3) Full recompute for the pre-recorded slots.
+    # Every slot with a None strategy must already be in slots_needing_recompute
+    # (populated by setup_sampler_step when a new request arrives).
+    assert None not in strategies or all(
+        seq_slots_list[i] in recompute_batch_slots for i in range(n) if strategies[i] is None
+    ), (
+        "Found slots with uncached strategies not registered in slots_needing_recompute. "
+        "Ensure setup_sampler_step is called before sample_async for new requests."
+    )
 
-    group_dict: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueBuilder] = (
-        defaultdict(lambda: RequestGroupValueBuilder([], [], [], [], []))
-    )
+    for slot in recompute_batch_slots:
+        i = slot_to_idx[slot]
+        request = requests_list[i]
+        has_draft_tokens = bool(request.py_draft_tokens)
+
+        strategy = _request_strategy(request, vocab_size=vocab_size)
+        store.strategies[slot] = strategy
+        strategies[i] = strategy
+
+        is_greedy = strategy == GREEDY
+        store.speculation_needs_probs[slot] = has_draft_tokens and not is_greedy
+        store.need_processed[slot] = (
+            request.py_logprobs_mode == LogprobMode.PROCESSED and request.return_log_probs
+        )
+        store.need_raw[slot] = (
+            request.py_logprobs_mode == LogprobMode.RAW and request.return_log_probs
+        )
+        store.needs_probs[slot] = store.speculation_needs_probs[slot] or store.need_processed[slot]
+
+        # Track non-greedy slots for future draft-token checks
+        if is_greedy:
+            store.non_greedy_slots.discard(slot)
+        else:
+            store.non_greedy_slots.add(slot)
+
+        # Keep beam-search slots in the recompute set (they always need it);
+        # remove everything else (strategy is now cached).
+        if not store.uses_beam_search[slot]:
+            store.slots_needing_recompute.discard(slot)
+
+    # Gather flags using list comprehension (faster than append in loop)
+    needs_probs = torch.tensor(
+        [store.needs_probs[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu"
+    )
+    speculation_needs_probs = torch.tensor(
+        [store.speculation_needs_probs[slot] for slot in seq_slots_list],
+        dtype=torch.bool,
+        device="cpu",
+    )
+    need_processed = torch.tensor(
+        [store.need_processed[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu"
+    )
+    need_raw = torch.tensor(
+        [store.need_raw[slot] for slot in seq_slots_list], dtype=torch.bool, device="cpu"
+    )
+    # Build strategy ID mapping for vectorized comparison (all on CPU).
+    # NB: set() does not preserve insertion order, so we use dict.fromkeys() to deduplicate while preserving order.
+    unique_strategies = list(dict.fromkeys(strategies))
+    strategy_to_id = {s: idx for idx, s in enumerate(unique_strategies)}
+    strategy_ids = torch.tensor(
+        [strategy_to_id[s] for s in strategies], dtype=torch.int32, device="cpu"
+    )
 
-    for req_index, req in enumerate(requests):
-        strategy = _request_strategy(req, vocab_size=vocab_size)
-        speculation_needs_probs = (
-            # NB: This criterion needs to be consistent with the gating of rejection sampling in
-            # process_draft_tokens.
-            TorchSampler._speculation_could_use_rejection_sampling(req, strategy)
-        )
+    # Pre-allocate group_ids array
+    group_ids = torch.empty(n, dtype=torch.int32, device="cpu")
+
+    _next_gid = 0
+
+    def _provision_gid() -> int:
+        nonlocal _next_gid
+        gid = _next_gid
+        _next_gid += 1
+        return gid
+
+    unique_keys: defaultdict[tuple, int] = defaultdict(_provision_gid)
+
+    # Vectorized assignment: loop over unique combinations instead of all requests
+    for sid, strategy in enumerate(unique_strategies):
+        strat_mask = strategy_ids == sid
+
+        for needs_probs_val in (False, True):
+            # Vectorized mask for this (strategy, needs_probs) group
+            mask = strat_mask & (needs_probs if needs_probs_val else ~needs_probs)
+
+            if torch.any(mask):
+                strategy_key = strategy_to_key(strategy)  # Called once per group!
+                key = (strategy_key, needs_probs_val)
+                group_ids[mask] = unique_keys[key]  # Vectorized assignment
+
+    # Efficient grouping using sort
+    sorted_group_ids, sorted_order = torch.sort(group_ids, stable=True)
+    # Use prepend to detect a "change" at position 0, giving us group_starts directly
+    group_starts = torch.nonzero(
+        torch.diff(sorted_group_ids, prepend=torch.tensor([-1], device="cpu")) != 0
+    ).squeeze(1)
+    group_ends = torch.cat([group_starts[1:], torch.tensor([n], device="cpu")])
+    # Since groups are assigned in request order, gid → key is just list indexing
+    id_to_key = list(unique_keys)
+
+    # Build result dictionary efficiently
+    result: dict[RequestGroupKey, RequestGroupValue] = {}
+
+    for gid, (start, end) in enumerate(zip(group_starts.tolist(), group_ends.tolist())):
+        group_sorted_indices = sorted_order[start:end]
+        strategy_key, needs_probs_bool = id_to_key[gid]
+
+        indices_arr = group_sorted_indices.to(torch.int32)
+        # Convert to list for Python list indexing
+        group_sorted_indices_list = group_sorted_indices.tolist()
+        group_strategies = [strategies[i] for i in group_sorted_indices_list]
+        spec_mask = speculation_needs_probs[group_sorted_indices]
+        spec_indices = indices_arr[spec_mask]
+        processed_flags = need_processed[group_sorted_indices]
+        raw_flags = need_raw[group_sorted_indices]
+
+        if pin_memory:
+            indices_tensor = indices_arr.pin_memory()
+            spec_tensor = spec_indices.pin_memory()
+            processed_tensor = processed_flags.pin_memory()
+            raw_tensor = raw_flags.pin_memory()
+        else:
+            indices_tensor = indices_arr
+            spec_tensor = spec_indices
+            processed_tensor = processed_flags
+            raw_tensor = raw_flags
+
+        result[RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs_bool)] = (
+            RequestGroupValue(
+                indices=indices_tensor,
+                strategies=group_strategies,
+                speculation_needs_probs_indices=spec_tensor,
+                need_processed_logprobs=processed_tensor,
+                need_raw_logprobs=raw_tensor,
+            )
+        )
-        need_processed_logprobs = (
-            req.py_logprobs_mode == LogprobMode.PROCESSED and req.return_log_probs
-        )
-        need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs
-        needs_probs = speculation_needs_probs or need_processed_logprobs
-        strategy_key = strategy_to_key(strategy)
-        group_dict_entry = group_dict[
-            RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs)
-        ]
-        group_dict_entry.indices.append(req_index)
-        group_dict_entry.strategies.append(strategy)
-        if speculation_needs_probs:
-            group_dict_entry.speculation_needs_probs_list.append(req_index)
-        group_dict_entry.need_processed_logprobs_list.append(need_processed_logprobs)
-        group_dict_entry.need_raw_logprobs_list.append(need_raw_logprobs)
-    return {
-        group_key: RequestGroupValue(
-            indices=torch.tensor(group_value.indices, pin_memory=pin_memory, dtype=torch.int32),
-            strategies=group_value.strategies,
-            speculation_needs_probs_indices=torch.tensor(
-                group_value.speculation_needs_probs_list, pin_memory=pin_memory, dtype=torch.int32
-            ),
-            need_processed_logprobs=torch.tensor(
-                group_value.need_processed_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
-            ),
-            need_raw_logprobs=torch.tensor(
-                group_value.need_raw_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
-            ),
-        )
-        for group_key, group_value in group_dict.items()
-    }
+
+    return result
 
 
 def add_token(
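The replacement body above builds its groups with a generic trick: assign each request a small integer group id, stable-sort the ids, and recover contiguous segments with torch.diff. A self-contained sketch of just that grouping step (illustrative only; the real function also carries strategies and per-request flags):

```python
import torch


def group_segments(group_ids: torch.Tensor) -> dict[int, torch.Tensor]:
    """Return {group_id: request indices}, indices ascending inside each group."""
    # Stable sort preserves the original (ascending) request order within a group.
    sorted_ids, order = torch.sort(group_ids, stable=True)
    # prepend=-1 forces a "change" at position 0, so starts always includes index 0.
    starts = torch.nonzero(
        torch.diff(sorted_ids, prepend=torch.tensor([-1], dtype=group_ids.dtype)) != 0
    ).squeeze(1)
    ends = torch.cat([starts[1:], torch.tensor([group_ids.numel()])])
    return {
        int(sorted_ids[s]): order[s:e]
        for s, e in zip(starts.tolist(), ends.tolist())
    }


ids = torch.tensor([1, 0, 1, 0, 2], dtype=torch.int32)
groups = group_segments(ids)
assert groups[0].tolist() == [1, 3]   # requests 1 and 3 share group 0, in request order
assert groups[1].tolist() == [0, 2]
assert groups[2].tolist() == [4]
```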
@@ -962,6 +1083,22 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         """Shape: batch_size, beam_width, sequence_length
         Usage: Stores the original tokens for each beam.
         This is used to recover the original tokens for each beam when streaming is enabled"""
+        speculation_needs_probs: list | None = None
+        """Length: max_num_sequences. True if request has draft tokens and non-greedy sampling."""
+        need_processed: list | None = None
+        """Length: max_num_sequences. True if logprob mode is PROCESSED and return_log_probs is set."""
+        need_raw: list | None = None
+        """Length: max_num_sequences. True if logprob mode is RAW and return_log_probs is set."""
+        needs_probs: list | None = None
+        """Length: max_num_sequences. True if speculation_needs_probs or need_processed."""
+        strategies: list | None = None
+        """Length: max_num_sequences. Stores cached Strategy tuple for each seq_slot."""
+        uses_beam_search: list | None = None
+        """Length: max_num_sequences. True if max_beam_width > 1 for this slot."""
+        slots_needing_recompute: set | None = None
+        """Slots where strategy needs (re)computation. Populated in setup_sampler_step."""
+        non_greedy_slots: set | None = None
+        """Slots with non-greedy strategies. Used to limit draft-token checks."""
 
         def __post_init__(self):
             assert self.new_tokens.shape == self.finish_reasons.shape
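The field docstrings above encode a single invariant: needs_probs is the OR of speculation_needs_probs and need_processed, and raw logprobs alone never force it. A small sketch of updating one slot under that invariant (hypothetical helper; the store object here is a plain stand-in for the fields listed above):

```python
from types import SimpleNamespace


def update_slot_flags(store, slot, *, has_draft, is_greedy, wants_processed, wants_raw):
    # Invariant from the Store docstrings: needs_probs == speculation_needs_probs or need_processed.
    store.speculation_needs_probs[slot] = has_draft and not is_greedy
    store.need_processed[slot] = wants_processed
    store.need_raw[slot] = wants_raw
    store.needs_probs[slot] = store.speculation_needs_probs[slot] or store.need_processed[slot]


n = 4
store = SimpleNamespace(
    speculation_needs_probs=[False] * n,
    need_processed=[False] * n,
    need_raw=[False] * n,
    needs_probs=[False] * n,
)
update_slot_flags(store, 1, has_draft=True, is_greedy=False, wants_processed=False, wants_raw=True)
assert store.needs_probs[1] is True    # draft tokens + non-greedy sampling need probs
update_slot_flags(store, 0, has_draft=False, is_greedy=True, wants_processed=False, wants_raw=True)
assert store.need_raw[0] is True
assert store.needs_probs[0] is False   # RAW logprobs alone do not require probs
```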
@@ -991,6 +1128,17 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         predecessor_beams: torch.Tensor | None = None
         original_tokens: torch.Tensor | None = None
         first_finish_reasons: torch.Tensor | None = None
+
+        # Use Python lists instead of tensors to avoid .item() overhead in hot loops
+        speculation_needs_probs: list = [False] * self.max_num_sequences
+        need_processed: list = [False] * self.max_num_sequences
+        need_raw: list = [False] * self.max_num_sequences
+        needs_probs: list = [False] * self.max_num_sequences
+        strategies: list = [None] * self.max_num_sequences
+        uses_beam_search: list = [False] * self.max_num_sequences
+        slots_needing_recompute: set = set()
+        non_greedy_slots: set = set()
+
         if self._use_beam_search:
             cache_indirection = torch.empty(
                 self.CACHE_INDIRECTION_SHAPE, device="cuda", dtype=torch.int
@@ -1018,6 +1166,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
             predecessor_beams=predecessor_beams,
             original_tokens=original_tokens,
             first_finish_reasons=first_finish_reasons,
+            speculation_needs_probs=speculation_needs_probs,
+            need_processed=need_processed,
+            need_raw=need_raw,
+            needs_probs=needs_probs,
+            strategies=strategies,
+            uses_beam_search=uses_beam_search,
+            slots_needing_recompute=slots_needing_recompute,
+            non_greedy_slots=non_greedy_slots,
         )
 
     @dataclass(frozen=True, kw_only=True)
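The comment about avoiding .item() reflects a general host-side cost: reading individual elements out of a torch tensor crosses the Python/C++ boundary and creates a Python scalar per element, while indexing a plain Python list does not. A rough, machine-dependent way to observe the gap (illustrative timing only):

```python
import timeit

import torch

n = 4096
flags_tensor = torch.zeros(n, dtype=torch.bool)
flags_list = [False] * n

# Read every flag once per iteration, 100 iterations each.
t_tensor = timeit.timeit(lambda: [flags_tensor[i].item() for i in range(n)], number=100)
t_list = timeit.timeit(lambda: [flags_list[i] for i in range(n)], number=100)
print(f"tensor .item(): {t_tensor:.4f}s, list indexing: {t_list:.4f}s")
```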
@@ -1528,7 +1684,8 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         for request in scheduled_requests.context_requests:
             if self._is_new_request(request):
                 assert request.py_seq_slot is not None
-                seq_slots.append(request.py_seq_slot)
+                slot = request.py_seq_slot
+                seq_slots.append(slot)
                 max_lens.append(
                     min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
                 )
@@ -1539,6 +1696,13 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                     raise ValueError("Beam search does not support multiple logprobs")
                 prompt_lens.append(request.py_prompt_len)
 
+                # Initialize cached data for this slot (prevents stale data from previous request)
+                self.store.strategies[slot] = None
+                self.store.uses_beam_search[slot] = _get_max_beam_width(request) > 1
+                # Mark slot for strategy recomputation in _group_requests_by_strategy_key
+                self.store.slots_needing_recompute.add(slot)
+                self.store.non_greedy_slots.discard(slot)  # reset until strategy is computed
+
         if len(seq_slots) > 0:
             full_list = [seq_slots, max_lens, end_ids]
             if self._use_beam_search:
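The slot initialization above also shows why beam-search slots are treated differently: they stay in slots_needing_recompute permanently (their effective beam width can change between steps), while all other slots drop out of the set once their strategy is cached. A small plain-Python sketch of that behaviour, with hypothetical names:

```python
# Sketch: cached slots resolve their strategy once; beam-search slots stay in
# the recompute set and are re-resolved on every decode step.
strategies = [None] * 4
uses_beam_search = [False, True, False, False]   # slot 1 uses beam search
slots_needing_recompute = {0, 1, 2, 3}           # every slot starts uncached

calls = {slot: 0 for slot in range(4)}

def resolve(slot):
    calls[slot] += 1
    return ("beam", calls[slot]) if uses_beam_search[slot] else ("greedy", None)

def refresh(slot):
    if slot in slots_needing_recompute:
        strategies[slot] = resolve(slot)
        if not uses_beam_search[slot]:           # beam-search slots are never "done"
            slots_needing_recompute.discard(slot)
    return strategies[slot]

for _ in range(3):                               # three decode steps
    for slot in range(4):
        refresh(slot)

assert calls[0] == calls[2] == calls[3] == 1     # resolved once, then cached
assert calls[1] == 3                             # beam-search slot re-resolved every step
```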
@@ -2450,8 +2614,10 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         grouped_requests = _group_requests_by_strategy_key(
             requests,
             pin_memory=True,
-            vocab_size=logits_cuda.size(1),
             strategy_to_key=self._grouped_sampler_cls.strategy_grouping_key,
+            store=self.store,
+            seq_slots=seq_slots,
+            vocab_size=logits_cuda.size(1),  # Dummy value; strategy should already be cached
         )
         grouped_requests_with_metadata = self._add_metadata_to_grouped_requests(
             requests,
@@ -84,9 +84,6 @@ class TestStrategySelection:
         sampling_config: SamplingConfig
         is_context_init_state: bool  # Torch sampler accesses this, but it does not affect this test
 
-        def __init__(self):
-            self._py_sampling_strategy: Strategy | None = None
-
         def get_beam_width_by_iter(
             self, for_next_iteration: bool = False
         ) -> int:  # Torch sampler accesses this, but it does not affect this test
@@ -1227,6 +1224,24 @@ class TestBatchedSampling:
             )
         )
 
+    @staticmethod
+    def _init_store_for_new_requests(
+        sampler: TorchSampler,
+        scheduled_requests: ScheduledRequests,
+    ) -> None:
+        """Initialize store for request slots that haven't been through setup_sampler_step.
+
+        In production, setup_sampler_step registers each new request's slot in
+        store.slots_needing_recompute so that _group_requests_by_strategy_key
+        knows to compute its strategy. Tests skip setup_sampler_step, so we
+        replicate the relevant initialization here to exercise the same code
+        path as production.
+        """
+        for req in scheduled_requests.all_requests():
+            slot = req.py_seq_slot
+            if sampler.store.strategies[slot] is None:
+                sampler.store.slots_needing_recompute.add(slot)
+
     def _sample(
         self,
         sampler: TorchSampler,
@@ -1241,6 +1256,9 @@ class TestBatchedSampling:
         Optionally, run sampling repeatedly, e.g., to gather statistics.
         """
         assert not scheduled_requests.context_requests
+        # Simulate the store initialization that setup_sampler_step performs
+        # for new requests in production.
+        self._init_store_for_new_requests(sampler, scheduled_requests)
         num_actual_repeats = num_repeats if num_repeats is not None else 1
 
         T = TypeVar("T")