[TRTLLM-10030][perf] pin host memory and batch sampler setup in beam search (#11390)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Author: mpikulski
Date: 2026-02-10 16:48:36 +01:00 (committed by GitHub)
parent 7d992972b2
commit 411fa9ff87


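Note on the recurring pattern in this change: small host-side metadata is staged in pinned (page-locked) CPU memory and then copied to the GPU with non_blocking=True, so the host-to-device transfer can overlap with other work instead of blocking the host. A minimal, standalone sketch of that pattern (illustrative values only, not the patched code; assumes a CUDA device is available):

import torch

seq_slots = [3, 7, 12]      # hypothetical per-request slot indices
max_lens = [128, 256, 64]   # hypothetical per-request length caps

# One pinned CPU tensor for all lists -> a single asynchronous H2D copy.
staged = torch.tensor([seq_slots, max_lens], dtype=torch.int32, pin_memory=True)
staged_cuda = staged.to(device="cuda", non_blocking=True)  # returns without waiting

seq_slots_cuda, max_lens_cuda = staged_cuda[0], staged_cuda[1]

Packing several lists into one tensor matches the existing "perform only a single copy" comment in the diff: one pinned allocation and one copy instead of one per list.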
@@ -1521,12 +1521,10 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         Args:
             requests: list[LlmRequest]. The requests to setup the sampler step for
         """
-        if self._use_beam_search:
-            self._prepare_beam_search(scheduled_requests.all_requests())
         seq_slots: list[int] = []
         max_lens: list[int] = []
         end_ids: list[int] = []
+        prompt_lens: list[int] = []
         for request in scheduled_requests.context_requests:
             if self._is_new_request(request):
                 assert request.py_seq_slot is not None
@@ -1535,47 +1533,89 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                     min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
                 )
                 end_ids.append(request.py_end_id if request.py_end_id is not None else -1)
+                if self._use_beam_search:
+                    if request.py_return_log_probs and request.py_num_logprobs > 1:
+                        raise ValueError("Beam search does not support multiple logprobs")
+                    prompt_lens.append(request.py_prompt_len)
         if len(seq_slots) > 0:
             full_list = [seq_slots, max_lens, end_ids]
+            if self._use_beam_search:
+                full_list.append(prompt_lens)
             # perform only a single copy
-            full_list_tensor = torch.tensor(full_list, device="cpu", dtype=torch.int32).to(
-                device="cuda", non_blocking=True
-            )
+            full_list_tensor = torch.tensor(
+                full_list, device="cpu", dtype=torch.int32, pin_memory=True
+            ).to(device="cuda", non_blocking=True)
             seq_slots_tensor = full_list_tensor[0]
             max_lens_tensor = full_list_tensor[1]
             end_ids_tensor = full_list_tensor[2]
             self.store.max_lengths_tensor[seq_slots_tensor] = max_lens_tensor
             self.store.end_ids[seq_slots_tensor] = end_ids_tensor
+            if self._use_beam_search:
+                prompt_lens_tensor = full_list_tensor[3]
+                self._prepare_beam_search(
+                    seq_slots=seq_slots_tensor,
+                    prompt_lens=prompt_lens_tensor,
+                )

     def _prepare_beam_search(
         self,
-        requests: list[LlmRequest],
+        seq_slots: torch.Tensor,
+        prompt_lens: torch.Tensor,
     ):
         """Prepare the beam search buffers for the requests

         If the last context chunk is being processed,
         initialize/reset the buffers for the request
         """
-        for request in requests:
-            if self._is_new_request(request):
-                if request.py_return_log_probs and request.py_num_logprobs > 1:
-                    raise ValueError("Beam search does not support multiple logprobs")
-                assert self.store.cache_indirection is not None
-                assert self.store.cum_log_probs is not None
-                assert self.store.sampled_log_probs is not None
-                assert self.store.sampled_log_prob_ranks is not None
-                assert self.store.predecessor_beams is not None
-                assert self.store.first_finish_reasons is not None
-                assert self.store.original_tokens is not None
-                self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
-                self.store.cum_log_probs[request.py_seq_slot].fill_(0)
-                self.store.sampled_log_probs[request.py_seq_slot].fill_(0)
-                self.store.sampled_log_prob_ranks[request.py_seq_slot].fill_(0)
-                self.store.predecessor_beams[request.py_seq_slot].fill_(0)
-                self.store.first_finish_reasons[request.py_seq_slot].fill_(
-                    FinishReason.NOT_FINISHED.value
-                )
-                self.store.original_tokens[request.py_seq_slot].fill_(0)
+        assert self.store.cache_indirection is not None
+        assert self.store.cum_log_probs is not None
+        assert self.store.sampled_log_probs is not None
+        assert self.store.sampled_log_prob_ranks is not None
+        assert self.store.predecessor_beams is not None
+        assert self.store.first_finish_reasons is not None
+        assert self.store.original_tokens is not None
+        self.store.cache_indirection[seq_slots, :, prompt_lens] = torch.zeros(
+            (1, 1),
+            dtype=self.store.cache_indirection.dtype,
+            device=self.store.cache_indirection.device,
+        )
+        self.store.cum_log_probs[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.cum_log_probs.dtype,
+            device=self.store.cum_log_probs.device,
+        )
+        self.store.sampled_log_probs[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.sampled_log_probs.dtype,
+            device=self.store.sampled_log_probs.device,
+        )
+        self.store.sampled_log_prob_ranks[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.sampled_log_prob_ranks.dtype,
+            device=self.store.sampled_log_prob_ranks.device,
+        )
+        self.store.predecessor_beams[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.predecessor_beams.dtype,
+            device=self.store.predecessor_beams.device,
+        )
+        self.store.first_finish_reasons[seq_slots] = (
+            torch.tensor(
+                FinishReason.NOT_FINISHED.value,
+                pin_memory=True,
+                dtype=self.store.first_finish_reasons.dtype,
+            )
+            .to(self.store.first_finish_reasons.device, non_blocking=True)
+            .unsqueeze(0)
+        )
+        self.store.original_tokens[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.original_tokens.dtype,
+            device=self.store.original_tokens.device,
+        )

     @torch.inference_mode()
     def _process_draft_tokens_rejection_sampling(
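The rewritten _prepare_beam_search above replaces a per-request Python loop of .fill_() calls with batched advanced-indexing writes keyed by the seq_slots tensor, so all new requests are reset with a handful of kernel launches. A simplified, standalone illustration of that idea (buffer shapes and slot indices are made up; assumes a CUDA device is available):

import torch

max_slots, beam_width = 16, 4
cum_log_probs = torch.zeros(max_slots, beam_width, device="cuda")

# Per-request loop: one small launch per slot.
for slot in [2, 5, 9]:
    cum_log_probs[slot].fill_(0)

# Batched: stage the slot indices once (pinned), copy them asynchronously,
# then reset every affected row in a single indexed assignment.
slots = torch.tensor([2, 5, 9], dtype=torch.int64, pin_memory=True)
slots_cuda = slots.to("cuda", non_blocking=True)
cum_log_probs[slots_cuda] = 0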
@@ -1994,6 +2034,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                        for request_idx in grouped_requests[key].indices
                    ],
                    dtype=torch.int32,
+                    pin_memory=True,
                ).to(
                    device="cuda", non_blocking=True
                ),  # end_ids should be on device for beam search
@@ -3223,7 +3264,11 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         seq_slots: torch.Tensor,
         seq_lens: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        seq_slots = seq_slots.to(dtype=torch.int32)  # int32 suffices here
+        seq_slots_int64 = seq_slots
+        seq_slots = torch.empty_like(
+            seq_slots_int64, dtype=torch.int32, pin_memory=True
+        )  # int32 suffices here
+        seq_slots[:] = seq_slots_int64
         raw_logits_cuda = model_outputs["logits"]
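The last hunk swaps a plain dtype cast for a pre-allocated pinned staging buffer: seq_slots.to(dtype=torch.int32) yields a pageable CPU tensor, whereas writing the values into a pinned buffer presumably keeps a later non_blocking copy to the GPU truly asynchronous. A standalone sketch of this staging step (illustrative values only, not the patched code):

import torch

seq_slots_int64 = torch.tensor([0, 3, 7])          # CPU, int64 slot indices
pageable = seq_slots_int64.to(dtype=torch.int32)   # cast alone: pageable memory

# Pinned int32 staging buffer; the assignment casts and copies in one step.
pinned = torch.empty_like(seq_slots_int64, dtype=torch.int32, pin_memory=True)
pinned[:] = seq_slots_int64
seq_slots_cuda = pinned.to("cuda", non_blocking=True)  # async H2D copy (needs CUDA)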