From fdf1c47d1d409cf7be8abf13652ac438c32897a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20C=C3=A1mpora?= <961215+dcampora@users.noreply.github.com> Date: Wed, 11 Jun 2025 08:18:13 +0200 Subject: [PATCH] [TRTLLM-4995][feat] TRTLLM Sampler log probs support (#4836) Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 27 +++++++++---- tensorrt_llm/_torch/pyexecutor/sampler.py | 40 ++++++++++++++++++- tests/unittest/_torch/test_return_logits.py | 8 +--- 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 3b8baaf88e..680206c619 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -118,15 +118,24 @@ class LogProbStorage: self.log_probs = [[] for _ in range(self.beam_width)] self.cum_log_probs = [0 for _ in range(self.beam_width)] - def append(self, new_probs: list[TokenLogprobs]): + def append(self, + new_probs: list[TokenLogprobs], + cum_log_probs: Optional[list[float]] = None): + """ + new_probs: [beam_width, num_tokens] + cum_log_probs: [beam_width] + """ if self.beam_width == -1: self._init(new_probs) assert len(new_probs) == self.beam_width, "Beam width mismatch" - for idx, probs in enumerate(new_probs): - self.log_probs[idx].extend(probs) - self.cum_log_probs[idx] += sum( - next(iter(prob.values())).logprob for prob in probs) + for beam_idx, probs in enumerate(new_probs): + self.log_probs[beam_idx].extend(probs) + if cum_log_probs is not None: + self.cum_log_probs[beam_idx] = cum_log_probs[beam_idx] + else: + self.cum_log_probs[beam_idx] += sum( + next(iter(prob.values())).logprob for prob in probs) class PyResult: @@ -157,9 +166,11 @@ class PyResult: if self._generation_logits: self._generation_logits.append(generation_logits) - def append_log_probs(self, log_probs: list[TokenLogprobs]): + def append_log_probs(self, + log_probs: list[TokenLogprobs], + cum_log_probs: Optional[list[float]] = None): if self._log_probs: - self._log_probs.append(log_probs) + self._log_probs.append(log_probs, cum_log_probs) @property def context_logits(self) -> torch.Tensor | None: @@ -250,7 +261,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): super().__init__( *args, client_id=client_id, - return_log_probs=False, + return_log_probs=return_log_probs, return_context_logits=False, return_generation_logits=False, stop_words_list=torch.tensor(stop_words_list, dtype=torch.int32) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index f0254dd46d..85cbb9da36 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -460,6 +460,8 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors): finished_sum: torch.Tensor finish_reasons: torch.Tensor sequence_lengths: torch.Tensor + log_probs: torch.Tensor + cum_log_probs: torch.Tensor @dataclass(kw_only=True) @@ -650,12 +652,23 @@ class TRTLLMSampler(Sampler): sequence_lengths = self.algs.decoder_state.sequence_lengths.to( 'cpu', non_blocking=True) + log_probs = torch.empty([0], dtype=torch.float, device='cpu') + cum_log_probs = torch.empty([0], dtype=torch.float, device='cpu') + if any(request.py_return_log_probs + for request in scheduled_requests.all_requests): + log_probs = self.algs.decoder_state.log_probs.to('cpu', + non_blocking=True) + cum_log_probs = self.algs.decoder_state.cum_log_probs.to( + 
+                'cpu', non_blocking=True)
+
         device = SampleStateTensors(new_tokens=new_tokens_device_tensor)

         host = SampleStateTensorsHostTRTLLM(new_tokens=new_output_tokens,
                                             finished_sum=finished_sum,
                                             finish_reasons=finish_reasons,
-                                            sequence_lengths=sequence_lengths)
+                                            sequence_lengths=sequence_lengths,
+                                            log_probs=log_probs,
+                                            cum_log_probs=cum_log_probs)

         sampler_event = torch.cuda.Event()
         sampler_event.record()
@@ -691,6 +704,9 @@ class TRTLLMSampler(Sampler):
             current_num_of_tokens = request.max_beam_num_tokens

             num_new_tokens = [0] * beam_width
+            log_probs = []
+            cum_log_probs = []
+
             for beam in range(beam_width):
                 seq_len = sequence_lengths_host_data[seq_slot * beam_width +
                                                      beam].item()
@@ -702,10 +718,32 @@ class TRTLLMSampler(Sampler):
                     new_token = new_tokens_host[step][seq_slot][beam]
                     request.add_new_token(new_token, beam)

+                    if request.py_return_log_probs:
+                        # NOTE: Log probs with drafting has not been tested yet.
+                        begin_log_probs_offset = request.prompt_len if request.sampling_config.beam_width == 1 else 0
+                        current_token = seq_len - request.prompt_len - num_new_tokens[
+                            beam] + step
+
+                        log_probs.append({
+                            new_token.item():
+                            Logprob(logprob=state.host.log_probs[seq_slot][beam]
+                                    [begin_log_probs_offset +
+                                     current_token].item(),
+                                    rank=1)
+                        })
+
+                if num_new_tokens[beam] > 0 and request.py_return_log_probs:
+                    cum_log_probs.append(
+                        state.host.cum_log_probs[seq_slot * beam_width +
+                                                 beam].item())
+
                 finish_reason = finish_reasons_host[seq_slot * beam_width +
                                                     beam].item()
                 request.set_finished_reason(FinishReason(finish_reason), beam)

+            if request.py_return_log_probs:
+                request.py_result.append_log_probs([log_probs], cum_log_probs)
+
             # Set number of tokens predicted per runtime iteration. Will be > 1 for speculative decoding.
             request.update_num_tokens_per_iteration(
                 request.max_beam_num_tokens - current_num_of_tokens,
diff --git a/tests/unittest/_torch/test_return_logits.py b/tests/unittest/_torch/test_return_logits.py
index de9917122c..555d4e9d3c 100644
--- a/tests/unittest/_torch/test_return_logits.py
+++ b/tests/unittest/_torch/test_return_logits.py
@@ -70,9 +70,7 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
-    if enable_trtllm_sampler and return_log_probs:
-        pytest.skip("TRTLLMSampler does not support return_log_probs")
-    elif not enable_trtllm_sampler and gather_context_logits:
+    if not enable_trtllm_sampler and gather_context_logits:
         pytest.skip("TorchSampler does not support gather_context_logits")

     build_config = BuildConfig()
@@ -141,9 +139,7 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
-    if enable_trtllm_sampler and return_log_probs:
-        pytest.skip("TRTLLMSampler does not support return_log_probs")
-    elif not enable_trtllm_sampler and gather_context_logits:
+    if not enable_trtllm_sampler and gather_context_logits:
         pytest.skip("TorchSampler does not support gather_context_logits")

     build_config = BuildConfig()
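
Reviewer note: below is a minimal, self-contained sketch of the `LogProbStorage.append` semantics this patch introduces, for anyone checking the two call paths. When the caller passes a precomputed `cum_log_probs` (the new TRTLLM sampler path), the stored cumulative value for each beam is overwritten with the decoder's value; when it is omitted (the existing TorchSampler path), the cumulative value keeps being accumulated from the per-token logprobs. The `Logprob` and `TokenLogprobs` definitions are simplified stand-ins inferred from how `append()` consumes them, and the class is re-declared only so the snippet runs outside the repo.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Logprob:
    """Simplified stand-in for tensorrt_llm's Logprob (logprob value + rank)."""
    logprob: float
    rank: int


# Stand-in inferred from append(): a TokenLogprobs is a per-token list of
# {token_id: Logprob} dicts for one beam.
TokenLogprobs = list[dict[int, Logprob]]


class LogProbStorage:
    """Re-declared here only so the example runs standalone; mirrors the
    append() behaviour added in llm_request.py above."""
    beam_width: int = -1

    def _init(self, new_probs: list[TokenLogprobs]):
        self.beam_width = len(new_probs)
        self.log_probs = [[] for _ in range(self.beam_width)]
        self.cum_log_probs = [0.0 for _ in range(self.beam_width)]

    def append(self,
               new_probs: list[TokenLogprobs],
               cum_log_probs: Optional[list[float]] = None):
        # new_probs: [beam_width, num_tokens]; cum_log_probs: [beam_width]
        if self.beam_width == -1:
            self._init(new_probs)
        assert len(new_probs) == self.beam_width, "Beam width mismatch"
        for beam_idx, probs in enumerate(new_probs):
            self.log_probs[beam_idx].extend(probs)
            if cum_log_probs is not None:
                # TRTLLM sampler path: take the decoder's cumulative value as-is.
                self.cum_log_probs[beam_idx] = cum_log_probs[beam_idx]
            else:
                # TorchSampler path: accumulate from the per-token logprobs.
                self.cum_log_probs[beam_idx] += sum(
                    next(iter(prob.values())).logprob for prob in probs)


# Single beam, two decode steps: the first append accumulates, the second
# overwrites with the sampler-provided cumulative value.
storage = LogProbStorage()
storage.append([[{7: Logprob(logprob=-0.5, rank=1)}]])
print(storage.cum_log_probs)  # [-0.5]  (accumulated)
storage.append([[{9: Logprob(logprob=-1.0, rank=1)}]], cum_log_probs=[-1.5])
print(storage.cum_log_probs)  # [-1.5]  (overwritten by the sampler's value)
```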