diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index da33c6cb30..27e2b99b59 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -1521,12 +1521,10 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         Args:
             requests: list[LlmRequest]. The requests to setup the sampler step for
         """
-        if self._use_beam_search:
-            self._prepare_beam_search(scheduled_requests.all_requests())
-
         seq_slots: list[int] = []
         max_lens: list[int] = []
         end_ids: list[int] = []
+        prompt_lens: list[int] = []
         for request in scheduled_requests.context_requests:
             if self._is_new_request(request):
                 assert request.py_seq_slot is not None
@@ -1535,47 +1533,89 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                     min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
                 )
                 end_ids.append(request.py_end_id if request.py_end_id is not None else -1)
+
+                if self._use_beam_search:
+                    if request.py_return_log_probs and request.py_num_logprobs > 1:
+                        raise ValueError("Beam search does not support multiple logprobs")
+                    prompt_lens.append(request.py_prompt_len)
+
         if len(seq_slots) > 0:
             full_list = [seq_slots, max_lens, end_ids]
+            if self._use_beam_search:
+                full_list.append(prompt_lens)
             # perform only a single copy
-            full_list_tensor = torch.tensor(full_list, device="cpu", dtype=torch.int32).to(
-                device="cuda", non_blocking=True
-            )
+            full_list_tensor = torch.tensor(
+                full_list, device="cpu", dtype=torch.int32, pin_memory=True
+            ).to(device="cuda", non_blocking=True)
             seq_slots_tensor = full_list_tensor[0]
             max_lens_tensor = full_list_tensor[1]
             end_ids_tensor = full_list_tensor[2]
             self.store.max_lengths_tensor[seq_slots_tensor] = max_lens_tensor
             self.store.end_ids[seq_slots_tensor] = end_ids_tensor
 
+            if self._use_beam_search:
+                prompt_lens_tensor = full_list_tensor[3]
+                self._prepare_beam_search(
+                    seq_slots=seq_slots_tensor,
+                    prompt_lens=prompt_lens_tensor,
+                )
+
     def _prepare_beam_search(
         self,
-        requests: list[LlmRequest],
+        seq_slots: torch.Tensor,
+        prompt_lens: torch.Tensor,
     ):
         """Prepare the beam search buffers for the requests
 
         If the last context chunk is being processed, initialize/reset the buffers for the request
         """
-        for request in requests:
-            if self._is_new_request(request):
-                if request.py_return_log_probs and request.py_num_logprobs > 1:
-                    raise ValueError("Beam search does not support multiple logprobs")
-                assert self.store.cache_indirection is not None
-                assert self.store.cum_log_probs is not None
-                assert self.store.sampled_log_probs is not None
-                assert self.store.sampled_log_prob_ranks is not None
-                assert self.store.predecessor_beams is not None
-                assert self.store.first_finish_reasons is not None
-                assert self.store.original_tokens is not None
-                self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
-                self.store.cum_log_probs[request.py_seq_slot].fill_(0)
-                self.store.sampled_log_probs[request.py_seq_slot].fill_(0)
-                self.store.sampled_log_prob_ranks[request.py_seq_slot].fill_(0)
-                self.store.predecessor_beams[request.py_seq_slot].fill_(0)
-                self.store.first_finish_reasons[request.py_seq_slot].fill_(
-                    FinishReason.NOT_FINISHED.value
-                )
-                self.store.original_tokens[request.py_seq_slot].fill_(0)
+        assert self.store.cache_indirection is not None
+        assert self.store.cum_log_probs is not None
+        assert self.store.sampled_log_probs is not None
+        assert self.store.sampled_log_prob_ranks is not None
+        assert self.store.predecessor_beams is not None
+        assert self.store.first_finish_reasons is not None
+        assert self.store.original_tokens is not None
+        self.store.cache_indirection[seq_slots, :, prompt_lens] = torch.zeros(
+            (1, 1),
+            dtype=self.store.cache_indirection.dtype,
+            device=self.store.cache_indirection.device,
+        )
+        self.store.cum_log_probs[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.cum_log_probs.dtype,
+            device=self.store.cum_log_probs.device,
+        )
+        self.store.sampled_log_probs[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.sampled_log_probs.dtype,
+            device=self.store.sampled_log_probs.device,
+        )
+        self.store.sampled_log_prob_ranks[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.sampled_log_prob_ranks.dtype,
+            device=self.store.sampled_log_prob_ranks.device,
+        )
+        self.store.predecessor_beams[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.predecessor_beams.dtype,
+            device=self.store.predecessor_beams.device,
+        )
+        self.store.first_finish_reasons[seq_slots] = (
+            torch.tensor(
+                FinishReason.NOT_FINISHED.value,
+                pin_memory=True,
+                dtype=self.store.first_finish_reasons.dtype,
+            )
+            .to(self.store.first_finish_reasons.device, non_blocking=True)
+            .unsqueeze(0)
+        )
+        self.store.original_tokens[seq_slots] = torch.zeros(
+            (1,),
+            dtype=self.store.original_tokens.dtype,
+            device=self.store.original_tokens.device,
+        )
 
     @torch.inference_mode()
     def _process_draft_tokens_rejection_sampling(
@@ -1994,6 +2034,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                         for request_idx in grouped_requests[key].indices
                     ],
                     dtype=torch.int32,
+                    pin_memory=True,
                 ).to(
                     device="cuda", non_blocking=True
                 ),  # end_ids should be on device for beam search
@@ -3223,7 +3264,11 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         seq_slots: torch.Tensor,
         seq_lens: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        seq_slots = seq_slots.to(dtype=torch.int32)  # int32 suffices here
+        seq_slots_int64 = seq_slots
+        seq_slots = torch.empty_like(
+            seq_slots_int64, dtype=torch.int32, pin_memory=True
+        )  # int32 suffices here
+        seq_slots[:] = seq_slots_int64
         raw_logits_cuda = model_outputs["logits"]
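
Note: the recurring change across these hunks is staging small host-side lists in pinned (page-locked) memory before the device copy. `non_blocking=True` only yields a truly asynchronous host-to-device transfer when the source is pinned; from pageable memory the copy degrades to a synchronizing one. A minimal sketch of the pattern outside the sampler, with made-up values (none of these variable names come from the PR):

```python
import torch

# Pack several small per-request lists into ONE pinned host tensor and issue
# a single async H2D copy, instead of one pageable (blocking) copy per list.
# Requires a CUDA device; pin_memory=True page-locks the host allocation.
seq_slots = [3, 7, 9]
max_lens = [128, 256, 64]
end_ids = [2, 2, -1]

full = torch.tensor(
    [seq_slots, max_lens, end_ids], dtype=torch.int32, pin_memory=True
).to(device="cuda", non_blocking=True)

# Row views on the device tensor; no further copies are made.
seq_slots_cuda, max_lens_cuda, end_ids_cuda = full[0], full[1], full[2]
```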
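The rewritten `_prepare_beam_search` also swaps the per-request Python loop of `.fill_(0)` calls for one broadcasted scatter per buffer, indexed by the `seq_slots` tensor that is already on device. A self-contained sketch of the indexing behavior it relies on; the buffer names mirror the diff but the shapes are invented for illustration:

```python
import torch

# Hypothetical shapes: [max_slots, beam_width] state buffers and a
# [max_slots, beam_width, max_seq_len] cache indirection.
max_slots, beam_width, max_seq_len = 8, 4, 16
cum_log_probs = torch.randn(max_slots, beam_width)
cache_indirection = torch.randint(0, beam_width, (max_slots, beam_width, max_seq_len))

seq_slots = torch.tensor([1, 5])    # slots being (re)initialized
prompt_lens = torch.tensor([3, 7])  # first generation position per slot

# Indexing dims 0 and 2 with equal-length tensors while slicing dim 1 selects
# a [len(seq_slots), beam_width] region; the (1,) / (1, 1) zeros broadcast
# over it, so every beam of every selected slot is reset in a single kernel.
cum_log_probs[seq_slots] = torch.zeros((1,), dtype=cum_log_probs.dtype)
cache_indirection[seq_slots, :, prompt_lens] = torch.zeros(
    (1, 1), dtype=cache_indirection.dtype
)
```

This moves the reset work off the Python interpreter: the number of launched device ops stays constant regardless of how many new requests arrive in the batch.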