[TRTLLM-10312][perf] Improve performance of _write_finish_reasons in TorchSampler (#10459)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Stefan Niebler 2026-01-29 17:06:09 +01:00 committed by GitHub
parent 80dd6e70c6
commit 7d31532850
3 changed files with 104 additions and 58 deletions

View File

@@ -929,6 +929,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
class Store:
new_tokens: torch.Tensor
"""Shape: See cpp DecoderState.getAllNewTokens()"""
max_lengths_tensor: torch.Tensor
"""Shape: batch_size
Usage: Stores the maximum lengths for each request"""
end_ids: torch.Tensor
"""Shape: batch_size
Usage: Stores the end ids for each request"""
finish_reasons: torch.Tensor
"""Shape: max_tokens, batch_size, beam_width
Usage: Stores the currently estimated finish_reasons for each request"""
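The two new Store buffers hold per-request scalars keyed by sequence slot, so they can be written once when a request starts and read back on the GPU every step. A minimal sketch of that access pattern, with assumed sizes and values rather than the sampler's actual buffers:

import torch

# assumed capacity; the real buffers are sized with max_num_sequences
max_num_sequences = 8
max_lengths = torch.zeros(max_num_sequences, dtype=torch.int32, device="cuda")
end_ids = torch.full((max_num_sequences,), -1, dtype=torch.int32, device="cuda")

# write once per new request, addressed by its sequence slot
seq_slots = torch.tensor([2, 5], device="cuda")
max_lengths[seq_slots] = torch.tensor([128, 64], dtype=torch.int32, device="cuda")
end_ids[seq_slots] = torch.tensor([2, 2], dtype=torch.int32, device="cuda")

# any later step gathers the per-request values without touching the host
per_request_max = max_lengths[seq_slots]  # shape: (num_active_requests,)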
@@ -974,6 +980,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
# Tensors necessary for all sampling methods
new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
max_lengths_tensor = int_tensor(self.max_num_sequences)
end_ids = int_tensor(self.max_num_sequences)
# Only used for logprobs processing or beam search
sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
@@ -1007,6 +1015,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
return self.Store(
new_tokens=new_tokens,
finish_reasons=finish_reasons,
max_lengths_tensor=max_lengths_tensor,
end_ids=end_ids,
cache_indirection=cache_indirection,
cache_indirection_buffer=cache_indirection_buffer,
cum_log_probs=cum_log_probs,
@@ -1072,6 +1082,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
FinishReason.CANCELLED,
] # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
}
self._max_tokens_offset = torch.arange(
1, self.max_tokens + 1, device="cuda", dtype=torch.int32
).view(-1, 1, 1)
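The offset tensor is built once in the constructor and cached, so each step can reuse it instead of rebuilding an index tensor on the host. A short sketch of its shape and how it broadcasts, with illustrative numbers rather than the sampler's configuration:

import torch

max_tokens = 4
offsets = torch.arange(1, max_tokens + 1, dtype=torch.int32).view(-1, 1, 1)
# offsets has shape (max_tokens, 1, 1); added to a (1, batch, 1) tensor of current
# sequence lengths it yields the candidate length after each newly accepted token
seq_lens = torch.tensor([10, 7], dtype=torch.int32).view(1, -1, 1)
candidate_lens = seq_lens + offsets  # shape (max_tokens, batch, 1), no Python loop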
self._grouped_sampler_cls: Type[GroupedStrategySampler]
if IS_FLASHINFER_AVAILABLE and not args.disable_flashinfer_sampling:
@@ -1525,14 +1538,47 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
return num_accepted_draft_tokens - 1
def setup_sampler_step(self, requests: ScheduledRequests):
def _is_new_request(self, request: LlmRequest) -> bool:
return (
not request.is_finished
and not request.py_is_draft
and (
(request.is_context_init_state and request.is_last_context_chunk)
or request.is_disagg_generation_transmission_complete
)
)
@override
def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
"""Setup the sampler step for the requests
Args:
requests: list[LlmRequest]. The requests to setup the sampler step for
"""
if self._use_beam_search:
self._prepare_beam_search(requests.all_requests())
self._prepare_beam_search(scheduled_requests.all_requests())
seq_slots: list[int] = []
max_lens: list[int] = []
end_ids: list[int] = []
for request in scheduled_requests.context_requests:
if self._is_new_request(request):
seq_slots.append(request.py_seq_slot)
max_lens.append(
min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
)
end_ids.append(request.py_end_id if request.py_end_id is not None else -1)
if len(seq_slots) > 0:
full_list = [seq_slots, max_lens, end_ids]
# stack the lists so only a single host-to-device copy is needed
full_list_tensor = torch.tensor(full_list, device="cpu", dtype=torch.int32).to(
device="cuda", non_blocking=True
)
seq_slots_tensor = full_list_tensor[0]
max_lens_tensor = full_list_tensor[1]
end_ids_tensor = full_list_tensor[2]
self.store.max_lengths_tensor[seq_slots_tensor] = max_lens_tensor
self.store.end_ids[seq_slots_tensor] = end_ids_tensor
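Stacking the three Python lists into one small CPU tensor means a single host-to-device transfer is issued instead of three. A sketch of the pattern with made-up data:

import torch

seq_slots = [2, 5, 6]
max_lens = [128, 64, 256]
end_ids = [2, 2, -1]

stacked = torch.tensor([seq_slots, max_lens, end_ids], dtype=torch.int32)  # (3, N) on the host
stacked_cuda = stacked.to(device="cuda", non_blocking=True)                # one H2D transfer
seq_slots_t, max_lens_t, end_ids_t = stacked_cuda                          # row views, no extra copies

Note that non_blocking copies from pageable host memory may still synchronize; the saving here comes mainly from issuing one transfer instead of three.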
def _prepare_beam_search(
self,
@@ -1544,10 +1590,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
initialize/reset the buffers for the request
"""
for request in requests:
if not request.is_finished and (
(request.is_context_init_state and request.is_last_context_chunk)
or request.is_disagg_generation_transmission_complete
):
if self._is_new_request(request):
if request.py_return_log_probs and request.py_num_logprobs > 1:
raise ValueError("Beam search does not support multiple logprobs")
self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
@@ -2083,13 +2126,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
dtype=torch.int64, # for index_fill_
pin_memory=True,
)
# necessary for beam search
seq_lens_host = (
torch.tensor(
[r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
)
if self._use_beam_search
else None
# necessary for beam search and max_length checks
seq_lens_host = torch.tensor(
[r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
)
new_tokens_host = self._process_requests(
scheduled_requests,
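seq_lens_host is now allocated unconditionally and in pinned memory, so the later non-blocking copy to the GPU can be truly asynchronous. A generic sketch of that pattern, with placeholder values:

import torch

values = [12, 7, 30]
host = torch.tensor(values, dtype=torch.int32, pin_memory=True)  # page-locked host memory
device = host.to(device="cuda", non_blocking=True)               # async copy on the current stream
# subsequent kernels queued on the same stream are ordered after the copy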
@@ -2102,12 +2141,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
finish_reasons = self.store.finish_reasons
seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
seq_lens = seq_lens_host.to(device="cuda", non_blocking=True)
first_finish_reasons = self.store.first_finish_reasons if self._use_beam_search else None
self._write_finish_reasons(
requests,
finish_reasons=finish_reasons,
seq_slots=seq_slots,
seq_lens=seq_lens,
new_tokens=new_tokens,
first_finish_reasons=first_finish_reasons,
predecessor_beams=self.store.predecessor_beams,
@@ -2760,6 +2801,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
*,
finish_reasons: torch.Tensor,
seq_slots: torch.Tensor,
seq_lens: torch.Tensor,
new_tokens: torch.Tensor,
first_finish_reasons: torch.Tensor | None = None,
predecessor_beams: torch.Tensor | None = None,
@@ -2775,7 +2817,11 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
new_tokens: a buffer containing the newly generated tokens.
Shape: (max_tokens, max_batch_size, max_beam_width)
"""
tokens = new_tokens[:, seq_slots.to(device=new_tokens.device, non_blocking=True)]
# seq_slots and seq_lens must be on the same device as new_tokens
assert seq_slots.device == new_tokens.device
assert seq_lens.device == new_tokens.device
tokens = new_tokens[:, seq_slots]
# fill with NOT_FINISHED so stale values from previous requests that used the same seq slot can be distinguished
finish_reasons.index_fill_(1, seq_slots, FinishReason.NOT_FINISHED.value)
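With seq_slots already resident on the GPU, the token columns of the active requests are gathered with advanced indexing and the finish-reason buffer is reset only for those slots. A small sketch with assumed shapes and an assumed NOT_FINISHED value:

import torch

max_tokens, max_batch_size, beam_width = 3, 6, 1
new_tokens = torch.randint(0, 100, (max_tokens, max_batch_size, beam_width), device="cuda")
finish_reasons = torch.zeros((max_tokens, max_batch_size, beam_width), dtype=torch.int32, device="cuda")

seq_slots = torch.tensor([1, 4], device="cuda")           # active slots, int64 as required by index_fill_
tokens = new_tokens[:, seq_slots]                         # (max_tokens, 2, beam_width)
NOT_FINISHED = 0                                          # assumed enum value
finish_reasons.index_fill_(1, seq_slots, NOT_FINISHED)    # clear stale reasons left by earlier requests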
@@ -2801,12 +2847,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
)
batched_finish_reasons = torch.where(
self._are_max_length(requests),
self._are_max_length(seq_lens, self.store.max_lengths_tensor[seq_slots]),
self._reason_tensors[FinishReason.LENGTH],
batched_finish_reasons,
)
batched_finish_reasons = torch.where(
self._are_end_id(requests, tokens),
self._are_end_id(self.store.end_ids[seq_slots], tokens),
self._reason_tensors[FinishReason.END_ID],
batched_finish_reasons,
)
@@ -2822,57 +2868,26 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
)
first_finish_reasons[seq_slots] = batched_first_finish_reasons
def _are_end_id(self, requests: list[LlmRequest], tokens: torch.Tensor) -> torch.Tensor:
end_ids_tensor = (
torch.tensor(
[
([req.py_end_id if req.py_end_id is not None else -1] * self.max_beam_width)
for req in requests
]
* self.max_tokens,
pin_memory=True,
dtype=tokens.dtype,
)
.view(self.max_tokens, len(requests), self.max_beam_width)
.to(device="cuda", non_blocking=True)
)
return tokens == end_ids_tensor
def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)
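The rewritten check compares the token buffer against the cached per-slot end ids through a broadcast view; expand() allocates nothing, so no host tensor has to be constructed per step. An illustrative sketch with tiny shapes:

import torch

max_tokens, batch, beam = 3, 2, 2
tokens = torch.tensor([[[5, 9], [7, 7]],
                       [[2, 5], [7, 1]],
                       [[2, 2], [1, 7]]])  # (max_tokens, batch, beam)
end_ids = torch.tensor([2, 7])             # one end id per request slot

mask = tokens == end_ids.view(1, -1, 1).expand(max_tokens, -1, beam)
# expand() returns a broadcast view (no copy); mask is True wherever a slot emitted its end id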
def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) -> torch.Tensor:
"""Checks which sequences are at or beyond the max length
Args:
requests: the requests to check the max length of
seq_lens: the current sequence lengths of the requests to check
max_seq_lens: the maximum allowed sequence length for each request
Returns:
A tensor of shape (max_tokens, len(seq_lens), max_beam_width)
where each element is True if the sequence is at or beyond its maximum length, False otherwise
"""
lengths_tensor = torch.tensor(
[
[
[
(req.get_num_tokens(beam_idx) + num_tokens)
for beam_idx in range(self.max_beam_width)
]
for req in requests
]
for num_tokens in range(1, self.max_tokens + 1)
]
lengths_tensor = (seq_lens.view(1, -1, 1) + self._max_tokens_offset).expand(
self.max_tokens, -1, self.max_beam_width
)
max_lengths_tensor = torch.tensor(
[
(
[min(req.py_max_new_tokens + req.orig_prompt_len, self.max_seq_len)]
* self.max_beam_width
)
for req in requests
]
* self.max_tokens
).view(self.max_tokens, len(requests), self.max_beam_width)
return (
(lengths_tensor >= max_lengths_tensor).pin_memory().to(device="cuda", non_blocking=True)
max_lengths_tensor = max_seq_lens.view(1, -1, 1).expand(
self.max_tokens, -1, self.max_beam_width
)
return lengths_tensor >= max_lengths_tensor
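The length check likewise becomes a single broadcast comparison on the GPU, combining the cached offsets with the per-slot current and maximum lengths; the old version rebuilt two host tensors with nested Python loops on every step. A sketch with illustrative values:

import torch

max_tokens, beam_width = 3, 2
offsets = torch.arange(1, max_tokens + 1, dtype=torch.int32).view(-1, 1, 1)  # cached once
seq_lens = torch.tensor([10, 6], dtype=torch.int32)       # current length per active request
max_seq_lens = torch.tensor([12, 7], dtype=torch.int32)   # allowed maximum per active request

lengths = (seq_lens.view(1, -1, 1) + offsets).expand(max_tokens, -1, beam_width)
limits = max_seq_lens.view(1, -1, 1).expand(max_tokens, -1, beam_width)
at_or_beyond = lengths >= limits  # True where a request would reach or exceed its limit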
_PAD_ID = -1
"""Pad with negative, doesn't matter what"""

View File

@@ -224,11 +224,18 @@ class MTPSampler(TorchSampler):
next_draft_tokens: torch.Tensor
new_tokens_lens: torch.Tensor
max_total_draft_tokens: torch.Tensor
finish_reasons: None = None # Necessary to satisfy the interface of TorchSampler.Store
# Necessary to satisfy the interface of TorchSampler.Store
finish_reasons: None = None
end_ids: None = None
max_lengths_tensor: None = None
def __post_init__(self):
pass # finish_reasons has no size to compare against new_tokens in MTPSampler
def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
# MTPSampler does not need to set up additional buffers before the sampler step
pass
def __init__(self, args: TorchSampler.Args, *, nextn: int):
self.mapping = None
self.draft_len = nextn

View File

@@ -701,16 +701,40 @@ class RequestCase:
seq_slots = torch.tensor(
[req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
)
seq_lens = torch.tensor(
[req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
)
new_tokens = torch.tensor(
[req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
).T
sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
max_seq_lens = torch.tensor(
[
min(
sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
)
for req in requests
],
dtype=torch.int32,
device="cuda",
)
end_ids = torch.tensor(
[
req.request.py_end_id if req.request.py_end_id is not None else -1
for req in requests
],
dtype=torch.int32,
device="cuda",
)
sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
sampler.store.end_ids[seq_slots] = end_ids
def run():
sampler._write_finish_reasons(
[req.request for req in requests],
finish_reasons=sampler.store.finish_reasons,
new_tokens=sampler.store.new_tokens,
seq_lens=seq_lens,
seq_slots=seq_slots,
)
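The test primes store.max_lengths_tensor and store.end_ids before exercising _write_finish_reasons through the run() closure. One common way to time such a GPU-side closure, shown here only as an assumption and not as the repository's benchmark harness, is with CUDA events:

import torch

def time_cuda(fn, iters=100):
    for _ in range(3):        # warm up to exclude one-time initialization cost
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call

# usage, assuming run is the closure defined above: time_cuda(run)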