From 196d94a4197c96f4a9a79ad7c174ced466dd55d4 Mon Sep 17 00:00:00 2001
From: mpikulski <206748156+ixlmar@users.noreply.github.com>
Date: Mon, 9 Feb 2026 16:13:58 +0100
Subject: [PATCH] [TRTLLM-10030][perf] avoid syncs in beam search + other
 improvements (#11349)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/sampler.py | 300 ++++++++++--------
 .../_torch/sampler/test_beam_search.py    |  22 +-
 2 files changed, 175 insertions(+), 147 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 31e56ccb05..9fdaf22fb1 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -766,6 +766,13 @@ class BeamHistory:
     cum_logprobs: torch.Tensor | None = None
 
 
+BeamHistoryBuilder: TypeAlias = Callable[[], BeamHistory | None]
+"""Builder for BeamHistory.
+
+Used to defer possibly unnecessary host-tensor construction until update_requests().
+"""
+
+
 @dataclass(kw_only=True)
 class SamplingRequestsMetadata:
     req_num_generated_tokens: torch.Tensor
@@ -789,7 +796,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors):
 
 @dataclass(kw_only=True)
 class SampleStateTorch(SampleState[SampleStateTensorsHostTorch, SampleStateTensors]):
-    beam_histories: list[BeamHistory | None] | None = None
+    beam_history_builders: list[BeamHistoryBuilder | None] | None = None
 
 
 class AsyncWorkerMixin:
@@ -1249,9 +1256,6 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         self,
         token_tensor: torch.Tensor,
         logprobs_tensor: torch.Tensor,
-        sampled_log_probs_indices: torch.Tensor | None,
-        sampled_log_probs_vals: torch.Tensor | None,
-        sampled_log_probs_rank: torch.Tensor | None,
     ) -> list[list[dict[int, Logprob]]]:
         """Convert the logprobs tensor to a list of lists of dictionaries of Logprob objects
 
         args:
             token_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
             logprobs_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
-            sampled_log_probs_indices: torch.Tensor | None. Shape: num_tokens
-            sampled_log_probs_vals: torch.Tensor | None. Shape: num_tokens
-            sampled_log_probs_rank: torch.Tensor | None. Shape: num_tokens
         output:
             list[list[dict[int, Logprob]]]. Shape: (beam_width, num_tokens)
         """
@@ -1274,38 +1275,13 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         token_log_probs: list[list[dict[int, Logprob]]] = []
         token_list = token_tensor.tolist()
         logprobs_list = logprobs_tensor.tolist()
-        sampled_log_probs_indices_list: list[int] | None = None
-        sampled_log_probs_vals_list: list[float] | None = None
-        sampled_log_probs_rank_list: list[int] | None = None
-        if sampled_log_probs_indices is not None:
-            sampled_log_probs_indices_list = sampled_log_probs_indices.tolist()
-            assert sampled_log_probs_vals is not None, "sampled_log_probs_vals must be provided"
-            assert sampled_log_probs_rank is not None, "sampled_log_probs_rank must be provided"
-            sampled_log_probs_vals_list = sampled_log_probs_vals.tolist()
-            sampled_log_probs_rank_list = sampled_log_probs_rank.tolist()
         for beam_idx in range(token_tensor.shape[0]):
             beam_token_log_probs: list[dict[int, Logprob]] = []
-            for step_idx, (topk_token, topk_logprob) in enumerate(
-                zip(token_list[beam_idx], logprobs_list[beam_idx])
-            ):
+            for topk_token, topk_logprob in zip(token_list[beam_idx], logprobs_list[beam_idx]):
                 logprobs = {
                     token: Logprob(logprob=logprob, rank=rank + 1)
                     for rank, (token, logprob) in enumerate(zip(topk_token, topk_logprob))
                 }
-                if sampled_log_probs_indices is not None:
-                    assert beam_idx == DEFAULT_BEAM_IDX, (
-                        "beam search does not need to explicitly handle sampled log probs"
-                    )
-                    assert sampled_log_probs_indices_list is not None
-                    assert sampled_log_probs_vals_list is not None
-                    assert sampled_log_probs_rank_list is not None
-                    if sampled_log_probs_indices_list[step_idx] not in logprobs:
-                        logprobs[sampled_log_probs_indices_list[step_idx]] = Logprob(
-                            logprob=sampled_log_probs_vals_list[step_idx],
-                            rank=max(
-                                token_tensor.shape[2] + 1, sampled_log_probs_rank_list[step_idx]
-                            ),
-                        )
                 beam_token_log_probs.append(logprobs)
             token_log_probs.append(beam_token_log_probs)
 
@@ -1732,7 +1708,12 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
             request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
         )
 
-    def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor, torch.Tensor]:
+    def _get_logprobs_from_request(
+        self,
+        request: LlmRequest,
+        pin_memory: bool = True,
+        preallocate_extra_steps: int = 0,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Extract the logprobs from the request
 
         Returns:
@@ -1743,25 +1724,27 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         assert request.py_num_logprobs == 0, (
             "Beam search only supports returning the sampled logprob per token"
         )
-        logprobs_tensor = torch.empty(
+        logprobs_tensor_full = torch.empty(
             (
                 request.sampling_config.beam_width,
-                num_generated_tokens,
+                num_generated_tokens + preallocate_extra_steps,
                 request.py_num_logprobs + 1,
             ),
-            device="cuda",
+            pin_memory=pin_memory,
             dtype=torch.float32,
         )
-        logprobs_indices_tensor = torch.empty(
+        logprobs_indices_tensor_full = torch.empty(
             (
                 request.sampling_config.beam_width,
-                num_generated_tokens,
+                num_generated_tokens + preallocate_extra_steps,
                 request.py_num_logprobs + 1,
             ),
-            device="cuda",
+            pin_memory=pin_memory,
             dtype=torch.int32,
         )
-        if hasattr(request.py_result._log_probs, "log_probs"):
+        logprobs_tensor = logprobs_tensor_full[:, :-preallocate_extra_steps, :]
+        logprobs_indices_tensor = logprobs_indices_tensor_full[:, :-preallocate_extra_steps, :]
+        if logprobs_tensor.numel() > 0:
             logprobs_list = request.py_result.log_probs
             assert logprobs_list is not None
             for beam_idx, beam_logprobs in enumerate(logprobs_list):
@@ -1770,12 +1753,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                     assert value.rank is not None
                     logprobs_tensor[beam_idx, token_idx, value.rank - 1] = value.logprob
                     logprobs_indices_tensor[beam_idx, token_idx, value.rank - 1] = key
-        return logprobs_tensor, logprobs_indices_tensor
+        return logprobs_tensor_full, logprobs_indices_tensor_full
 
-    def _create_beam_history(
+    def _prepare_beam_history(
         self,
         request: LlmRequest,
-    ) -> BeamHistory | None:
+        *,
+        finish_reasons: torch.Tensor,
+    ) -> BeamHistoryBuilder | None:
         """Correct the stored tokens for each beam and return it as a BeamHistory object.
 
         Beam Search sampling only adds new tokens to the beam.
@@ -1784,9 +1769,30 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         If logprobs are requested, the function also corrects the stored logprobs for each beam.
         The function returns a BeamHistory object that contains the corrected tokens and logprobs for each beam.
 
+        Note: To defer the decision whether or not to skip BeamHistory construction until update_requests(), only
+            a builder (BeamHistoryBuilder) is returned here. The builder contains host tensors which are
+            being populated asynchronously. Hence, it can only be invoked after async D2H copies have completed,
+            e.g., after awaiting state.sampler_event in update_requests.
+
         arguments:
             request: The request to create the beam history for
+            finish_reasons: The first finish reason encountered for each beam of the request.
+                Shape: (max_tokens, max_beam_width)
         """
+
+        # Gather data used for skipping beam history processing
+        need_finalize_due_to_stop_words = self._check_stop_words_length(request)
+        if need_finalize_due_to_stop_words:
+            need_history = torch.tensor(True)
+        else:
+            should_stop = self._check_beam_search_stop_criteria(
+                request,
+                finish_reasons=finish_reasons,
+            )
+            need_history = should_stop
+        # enqueue async D2H copy
+        need_history = self._copy_to_host(need_history)
+
         num_tokens = request.max_beam_num_tokens + 1  # last token is not yet added
         prompt_length = request.py_prompt_len
         num_generated_tokens = num_tokens - prompt_length
@@ -1795,73 +1801,98 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         if num_generated_tokens == 0 or request.state == LlmRequestState.GENERATION_COMPLETE:
             # early return if no tokens have been generated yet or the request is already finished
             return None
+
         assert self.store.cache_indirection is not None
         assert self.store.original_tokens is not None
-        assert self.store.sampled_log_probs is not None
         cache_indirection = self.store.cache_indirection[
             request.py_seq_slot, :num_beams, prompt_length:num_tokens
         ]
         current_path = self.store.original_tokens[
             request.py_seq_slot, :num_beams, prompt_length:num_tokens
         ]
-        new_path = torch.zeros_like(current_path)
-        # initialize each beam with its own index
+        # enqueue async D2H copies
+        cache_indirection = self._copy_to_host(cache_indirection)
+        current_path = self._copy_to_host(current_path)
+
+        def _post_process_path() -> torch.Tensor:
+            # Gather the correct tokens for each beam
+            new_path = torch.zeros_like(current_path)
+            torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
+            return new_path
 
-        # Gather the correct tokens and logprobs for each beam
-        torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
         if request.py_return_log_probs:
+            assert self.store.sampled_log_probs is not None
             assert self.store.cum_log_probs is not None
-            current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(request)
-            # concatenate the newly generated logprobs and newly
-            # generated tokens to the current logprobs and logprobs indices
-            current_logprobs = torch.cat(
-                [
-                    current_logprobs,
-                    self.store.sampled_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1),
-                ],
-                dim=1,
-            )
-            current_logprobs_indices = torch.cat(
-                [
-                    current_logprobs_indices,
-                    self.store.new_tokens[0, request.py_seq_slot, :num_beams].view(-1, 1, 1),
-                ],
-                dim=1,
-            )
-            # Initialize the buffers to store the results
-            new_logprobs = torch.zeros_like(current_logprobs)
-            new_logprobs_indices = torch.zeros_like(current_logprobs_indices)
-
-            cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
-                -1, -1, current_logprobs.shape[2]
-            )
-            torch.gather(
-                input=current_logprobs,
-                dim=0,
-                index=cache_indirection_for_logprobs,
-                out=new_logprobs,
-            )
-            torch.gather(
-                input=current_logprobs_indices,
-                dim=0,
-                index=cache_indirection_for_logprobs,
-                out=new_logprobs_indices,
+            sampled_log_probs = self.store.sampled_log_probs[request.py_seq_slot, :num_beams].view(
+                -1, 1
             )
+            sampled_logprobs_indices = self.store.new_tokens[
+                0, request.py_seq_slot, :num_beams
+            ].view(-1, 1)
             cum_logprobs = self.store.cum_log_probs[request.py_seq_slot, :num_beams]
+
+            # enqueue async D2H copies
+            sampled_log_probs = self._copy_to_host(sampled_log_probs)
+            sampled_logprobs_indices = self._copy_to_host(sampled_logprobs_indices)
+            cum_logprobs = self._copy_to_host(cum_logprobs)
+
+            def _maybe_postprocess_logprobs() -> tuple[
+                torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
+            ]:
+                # Gather the correct logprobs for each beam
+
+                current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(
+                    request, preallocate_extra_steps=1
+                )
+                # concatenate the newly generated logprobs and newly
+                # generated tokens to the current logprobs and logprobs indices
+                current_logprobs[:, -1, :].copy_(sampled_log_probs)
+                current_logprobs_indices[:, -1, :].copy_(sampled_logprobs_indices)
+
+                # Initialize the buffers to store the results
+                new_logprobs = torch.zeros_like(current_logprobs)
+                new_logprobs_indices = torch.zeros_like(current_logprobs_indices)
+
+                cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
+                    -1, -1, current_logprobs.shape[2]
+                )
+                torch.gather(
+                    input=current_logprobs,
+                    dim=0,
+                    index=cache_indirection_for_logprobs,
+                    out=new_logprobs,
+                )
+                torch.gather(
+                    input=current_logprobs_indices,
+                    dim=0,
+                    index=cache_indirection_for_logprobs,
+                    out=new_logprobs_indices,
+                )
+                return new_logprobs, new_logprobs_indices, cum_logprobs
+
+        else:
+
+            def _maybe_postprocess_logprobs() -> tuple[
+                torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
+            ]:
+                return None, None, None
+
+        def _builder() -> BeamHistory | None:
+            if not need_history.item():
+                return None
+
+            new_path = _post_process_path()
+            new_logprobs, new_logprobs_indices, cum_logprobs = _maybe_postprocess_logprobs()
+
             return BeamHistory(
                 tokens=new_path,
                 logprobs=new_logprobs,
                 logprobs_indices=new_logprobs_indices,
                 cum_logprobs=cum_logprobs,
             )
-        else:
-            return BeamHistory(
-                tokens=new_path,
-                logprobs=None,
-                logprobs_indices=None,
-                cum_logprobs=None,
-            )
+
+        return _builder
 
     def _finalize_beam(
         self,
@@ -1897,23 +1928,19 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                 f"Beam_history.cum_logprobs.shape[0] should equal beam width: \
                     {beam_history.cum_logprobs.shape[0]} != {beam_width}"
             )
-        valid_tokens = (beam_history.tokens != BEAM_SEARCH_PAD_TOKEN).sum(dim=-1)
+        valid_tokens = (beam_history.tokens != BEAM_SEARCH_PAD_TOKEN).sum(dim=-1).tolist()
         gen_token_list = []
         gen_log_probs_list = []
         for beam_idx in range(beam_width):
-            gen_token_list.append(beam_history.tokens[beam_idx, : valid_tokens[beam_idx]].tolist())
+            beam_valid_tokens = valid_tokens[beam_idx]
+            gen_token_list.append(beam_history.tokens[beam_idx, :beam_valid_tokens].tolist())
             if request.py_return_log_probs:
                 assert beam_history.logprobs_indices is not None
                 assert beam_history.logprobs is not None
                 gen_log_probs_list.append(
                     self._convert_logprobs_tensor_to_list(
-                        beam_history.logprobs_indices[
-                            beam_idx : beam_idx + 1, : valid_tokens[beam_idx]
-                        ],
-                        beam_history.logprobs[beam_idx : beam_idx + 1, : valid_tokens[beam_idx]],
-                        None,
-                        None,
-                        None,
+                        beam_history.logprobs_indices[beam_idx : beam_idx + 1, :beam_valid_tokens],
+                        beam_history.logprobs[beam_idx : beam_idx + 1, :beam_valid_tokens],
                     )[0]
                 )
         request.set_generated_tokens(gen_token_list)
@@ -1989,11 +2016,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         self,
         request: LlmRequest,
         finish_reasons: torch.Tensor,
-    ) -> bool:
-        """Check if the stop criteria is met for the request"""
+    ) -> torch.Tensor:
+        """Check if the stop criteria is met for the request.
+
+        Returns a boolean tensor of shape (), whose value is computed asynchronously.
+        """
         return (
             finish_reasons[: request.sampling_config.beam_width] > 0
-        ).sum().item() == request.sampling_config.beam_width  # NB: This syncs
+        ).sum() == request.sampling_config.beam_width
 
     def _check_stop_words_length(self, request: LlmRequest) -> bool:
         """Check if the stop words length is greater than 1"""
@@ -2006,23 +2036,20 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         return False
 
     @nvtx_range("maybe_create_beam_histories")
-    def _maybe_create_beam_histories(
+    def _prepare_beam_histories(
         self,
         requests: list[LlmRequest],
         finish_reasons: torch.Tensor,
-        beam_histories: list[BeamHistory | None],
-    ) -> None:
-        """Create the corrected tokens and logprobs for each beam of a request
+    ) -> list[BeamHistoryBuilder | None]:
+        """Create the corrected tokens and logprobs for each beam of a request.
 
-        This function creates a beam history object containing the corrected
-        tokens and logprobs for each beam of a request"""
-        for req_idx, req in enumerate(requests):
-            should_stop = self._check_beam_search_stop_criteria(
-                req, finish_reasons=finish_reasons[req.py_seq_slot]
-            )
-            need_finalize_due_to_stop_words = self._check_stop_words_length(req)
-            if should_stop or req.streaming or need_finalize_due_to_stop_words:
-                beam_histories[req_idx] = self._create_beam_history(req)
+        The builders returned by this function create a beam history object containing
+        the corrected tokens and logprobs for each beam of a request.
+ """ + return [ + self._prepare_beam_history(req, finish_reasons=finish_reasons[req.py_seq_slot]) + for req in requests + ] @override @nvtx_range("update_requests") @@ -2040,22 +2067,31 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin): finish_reasons = state.host.finish_reasons_list() new_tokens_list = new_tokens.tolist() - beam_histories = state.beam_histories + logprobs_state_list: LogProbsStateList | None = None if state.host.logprobs_state is not None: logprobs_state_list = LogProbsStateList.from_logprobs_state(state.host.logprobs_state) + beam_history_builders = state.beam_history_builders + assert (beam_history_builders is not None) == self._use_beam_search + + def _maybe_build_beam_history(req_idx: int) -> BeamHistory | None: + if ( + beam_history_builders is not None + and (beam_history_builder := beam_history_builders[req_idx]) is not None + ): + return beam_history_builder() + else: + return None + for req_idx, req in enumerate(state.scheduled_requests.context_requests): if ( req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0 ): continue - if beam_histories is not None and beam_histories[req_idx] is not None: - self._finalize_beam( - req, - cast(BeamHistory, beam_histories[req_idx]), - ) + if (beam_history := _maybe_build_beam_history(req_idx)) is not None: + self._finalize_beam(req, beam_history) else: for beam_idx in range(req.sampling_config.beam_width): add_token(req, new_tokens_list, beam_idx=beam_idx) @@ -2069,12 +2105,10 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin): ): if req.state == LlmRequestState.GENERATION_COMPLETE: continue + if req.sampling_config.beam_width > 1: - if beam_histories is not None and beam_histories[req_idx] is not None: - self._finalize_beam( - req, - cast(BeamHistory, beam_histories[req_idx]), - ) + if (beam_history := _maybe_build_beam_history(req_idx)) is not None: + self._finalize_beam(req, beam_history) else: for beam_idx in range(req.sampling_config.beam_width): # Beam search does not support speculative decoding. 
@@ -2083,7 +2117,6 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                 self._handle_finish_reasons(req, state.host.finish_reasons, finish_reasons)
                 req.py_num_accepted_draft_tokens = 0
                 req.py_rewind_len = 0
-
             else:
                 processed = 1
                 num_accepted = self.process_draft_tokens(
@@ -2175,7 +2208,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         )
         finish_reasons_host = self._copy_to_host(finish_reasons)
 
-        beam_histories: list[BeamHistory | None] = [None] * len(requests)
+        beam_history_builders = None
         if self._use_beam_search:
             assert first_finish_reasons is not None
             assert seq_lens_host is not None, "seq_lens is required for beam search"
@@ -2185,8 +2218,8 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
             seq_lens = seq_lens_host.to(device="cuda", non_blocking=True)
             first_finish_reasons_host = self._copy_to_host(self.store.first_finish_reasons)
             self._update_original_tokens(seq_slots, seq_lens, new_tokens)
-            self._maybe_create_beam_histories(
-                requests, finish_reasons=first_finish_reasons, beam_histories=beam_histories
+            beam_history_builders = self._prepare_beam_histories(
+                requests, finish_reasons=first_finish_reasons
             )
         else:
             first_finish_reasons_host = None
@@ -2231,7 +2264,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
                 logprobs_state=logprobs_state,
             ),
             sampler_event=sampler_event,
-            beam_histories=beam_histories,
+            beam_history_builders=beam_history_builders,
         )
 
     @staticmethod
@@ -2885,12 +2918,11 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         if first_finish_reasons is not None:
             # store the first stop reason for each beam of a seq_slot.
             batched_first_finish_reasons = first_finish_reasons[seq_slots]
-            batched_first_finish_reasons = torch.where(
+            first_finish_reasons[seq_slots, ...] = torch.where(
                 batched_first_finish_reasons == FinishReason.NOT_FINISHED.value,
                 batched_finish_reasons,
                 batched_first_finish_reasons,
             )
-            first_finish_reasons[seq_slots] = batched_first_finish_reasons
 
     def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
         return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)
diff --git a/tests/unittest/_torch/sampler/test_beam_search.py b/tests/unittest/_torch/sampler/test_beam_search.py
index 1489b3c564..169fb0a6de 100644
--- a/tests/unittest/_torch/sampler/test_beam_search.py
+++ b/tests/unittest/_torch/sampler/test_beam_search.py
@@ -635,9 +635,6 @@ def test_create_beam_history():
     token_logprobs = sampler._convert_logprobs_tensor_to_list(
         original_logprob_indices[:beam_width, :num_generated_tokens - 1],
         original_logprobs[:beam_width, :num_generated_tokens - 1],
-        None,
-        None,
-        None,
     )
     request.py_result.set_log_probs(
         token_logprobs,
@@ -670,18 +667,18 @@ def test_create_beam_history():
                                                 num_generated_tokens - 1, 0]
 
     # test
-    beam_history = sampler._create_beam_history(request)
+    beam_history_builder = sampler._prepare_beam_history(
+        request, finish_reasons=torch.ones((beam_width, ), dtype=torch.int))
+    torch.cuda.synchronize()
+    beam_history = beam_history_builder()
 
     # expected selection:
     # Currently beam history only contains the generated tokens, not the prompt tokens.
     expected_tokens = torch.zeros(
-        (sampler.max_beam_width, num_generated_tokens),
-        dtype=torch.int32,
-        device=original_tokens.device)
+        (sampler.max_beam_width, num_generated_tokens), dtype=torch.int32)
     expected_logprobs = torch.zeros(
         (beam_width, num_generated_tokens, original_logprobs.shape[-1]),
-        dtype=torch.float32,
-        device=original_logprobs.device)
+        dtype=torch.float32)
     for gen_idx in range(num_generated_tokens):
         token_idx = prompt_len + gen_idx
         expected_tokens[:, gen_idx] = original_tokens[
@@ -694,10 +691,9 @@ def test_create_beam_history():
     # test logprobs as well
     torch.testing.assert_close(beam_history.logprobs[:beam_width],
                                expected_logprobs[:beam_width])
-    torch.testing.assert_close(beam_history.cum_logprobs[:beam_width],
-                               original_cum_logprobs[seq_slot, :beam_width])
-
-    return
+    torch.testing.assert_close(
+        beam_history.cum_logprobs[:beam_width],
+        original_cum_logprobs[seq_slot, :beam_width].to("cpu"))
 
 
 def test_finish_beams():
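
Reviewer note (not part of the patch): the change above avoids blocking .item() reads on the beam-search path by enqueueing asynchronous device-to-host copies into pinned host tensors and deferring any host-side reads to BeamHistoryBuilder callables that run in update_requests(), after the sampler event has been awaited. Below is a minimal, self-contained PyTorch sketch of that deferred-builder pattern; the names (prepare_flag, builder, finished_*) are illustrative assumptions and do not come from the TensorRT-LLM code.

import torch


def prepare_flag(finished_gpu: torch.Tensor):
    """Enqueue an async D2H copy now; return a builder that reads the result later."""
    # Pinned host buffer so that copy_(..., non_blocking=True) does not block the host.
    finished_host = torch.empty(
        finished_gpu.shape, dtype=finished_gpu.dtype, device="cpu", pin_memory=True
    )
    finished_host.copy_(finished_gpu, non_blocking=True)
    event = torch.cuda.Event()
    event.record()  # the host buffer is valid once this event has completed

    def builder() -> list | None:
        # Deferred host access: wait for the copy only at the point where the value is needed.
        event.synchronize()
        return finished_host.tolist() if bool(finished_host.any()) else None

    return builder


if __name__ == "__main__" and torch.cuda.is_available():
    build = prepare_flag(torch.tensor([True, False, True], device="cuda"))
    # ...more GPU work can be enqueued here without forcing a synchronization...
    print(build())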