[TRTLLM-4995][feat] TRTLLM Sampler log probs support (#4836)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Daniel Cámpora 2025-06-11 08:18:13 +02:00 committed by GitHub
parent 00991d1520
commit fdf1c47d1d
3 changed files with 60 additions and 15 deletions


@@ -118,15 +118,24 @@ class LogProbStorage:
         self.log_probs = [[] for _ in range(self.beam_width)]
         self.cum_log_probs = [0 for _ in range(self.beam_width)]
 
-    def append(self, new_probs: list[TokenLogprobs]):
+    def append(self,
+               new_probs: list[TokenLogprobs],
+               cum_log_probs: Optional[list[float]] = None):
+        """
+        new_probs: [beam_width, num_tokens]
+        cum_log_probs: [beam_width]
+        """
         if self.beam_width == -1:
             self._init(new_probs)
 
         assert len(new_probs) == self.beam_width, "Beam width mismatch"
-        for idx, probs in enumerate(new_probs):
-            self.log_probs[idx].extend(probs)
-            self.cum_log_probs[idx] += sum(
-                next(iter(prob.values())).logprob for prob in probs)
+        for beam_idx, probs in enumerate(new_probs):
+            self.log_probs[beam_idx].extend(probs)
+            if cum_log_probs is not None:
+                self.cum_log_probs[beam_idx] = cum_log_probs[beam_idx]
+            else:
+                self.cum_log_probs[beam_idx] += sum(
+                    next(iter(prob.values())).logprob for prob in probs)
 
 
 class PyResult:
@@ -157,9 +166,11 @@ class PyResult:
         if self._generation_logits:
             self._generation_logits.append(generation_logits)
 
-    def append_log_probs(self, log_probs: list[TokenLogprobs]):
+    def append_log_probs(self,
+                         log_probs: list[TokenLogprobs],
+                         cum_log_probs: Optional[list[float]] = None):
         if self._log_probs:
-            self._log_probs.append(log_probs)
+            self._log_probs.append(log_probs, cum_log_probs)
 
     @property
     def context_logits(self) -> torch.Tensor | None:
@@ -250,7 +261,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         super().__init__(
             *args,
             client_id=client_id,
-            return_log_probs=False,
+            return_log_probs=return_log_probs,
             return_context_logits=False,
             return_generation_logits=False,
             stop_words_list=torch.tensor(stop_words_list, dtype=torch.int32)
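
For context, the two code paths of the updated LogProbStorage.append() can be illustrated with a minimal, self-contained sketch. The Logprob stand-in and the toy values below are assumptions made for illustration, not the real classes touched by this diff:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Logprob:  # stand-in for the real Logprob record used above
    logprob: float
    rank: int

def accumulate(cum, new_probs, cum_log_probs: Optional[list[float]] = None):
    # Mirrors the updated append(): overwrite the cumulative value with the
    # decoder-provided one when it is given, otherwise keep summing the
    # per-token log probabilities.
    for beam_idx, probs in enumerate(new_probs):
        if cum_log_probs is not None:
            cum[beam_idx] = cum_log_probs[beam_idx]
        else:
            cum[beam_idx] += sum(next(iter(p.values())).logprob for p in probs)
    return cum

cum = [0.0]
accumulate(cum, [[{7: Logprob(-0.1, 1)}, {9: Logprob(-0.3, 1)}]])  # summing path: cum -> [-0.4]
accumulate(cum, [[{4: Logprob(-0.2, 1)}]], cum_log_probs=[-0.6])   # decoder-provided path: cum -> [-0.6]

With the TRTLLM sampler the decoder already tracks the running total, so append() simply adopts its cum_log_probs instead of recomputing the sum in Python.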


@@ -460,6 +460,8 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
     finished_sum: torch.Tensor
     finish_reasons: torch.Tensor
     sequence_lengths: torch.Tensor
+    log_probs: torch.Tensor
+    cum_log_probs: torch.Tensor
 
 
 @dataclass(kw_only=True)
@@ -650,12 +652,23 @@ class TRTLLMSampler(Sampler):
         sequence_lengths = self.algs.decoder_state.sequence_lengths.to(
             'cpu', non_blocking=True)
 
+        log_probs = torch.empty([0], dtype=torch.float, device='cpu')
+        cum_log_probs = torch.empty([0], dtype=torch.float, device='cpu')
+        if any(request.py_return_log_probs
+               for request in scheduled_requests.all_requests):
+            log_probs = self.algs.decoder_state.log_probs.to('cpu',
+                                                             non_blocking=True)
+            cum_log_probs = self.algs.decoder_state.cum_log_probs.to(
+                'cpu', non_blocking=True)
+
         device = SampleStateTensors(new_tokens=new_tokens_device_tensor)
 
         host = SampleStateTensorsHostTRTLLM(new_tokens=new_output_tokens,
                                             finished_sum=finished_sum,
                                             finish_reasons=finish_reasons,
-                                            sequence_lengths=sequence_lengths)
+                                            sequence_lengths=sequence_lengths,
+                                            log_probs=log_probs,
+                                            cum_log_probs=cum_log_probs)
 
         sampler_event = torch.cuda.Event()
         sampler_event.record()
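
Note that these host tensors are produced with non_blocking=True and are only safe to read once the recorded sampler_event has completed. A generic sketch of that copy-then-record pattern in plain PyTorch (not the repo's exact synchronization code):

import torch

src = torch.randn(4, device='cuda')
dst = src.to('cpu', non_blocking=True)  # async D2H copy (fully async only into pinned memory)

event = torch.cuda.Event()
event.record()       # marks the point after the copy was enqueued on the current stream

# ... enqueue more GPU work here ...

event.synchronize()  # host blocks until the copy (and everything before it) is done
print(dst)           # now safe to read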
@@ -691,6 +704,9 @@ class TRTLLMSampler(Sampler):
             current_num_of_tokens = request.max_beam_num_tokens
             num_new_tokens = [0] * beam_width
 
+            log_probs = []
+            cum_log_probs = []
+
             for beam in range(beam_width):
                 seq_len = sequence_lengths_host_data[seq_slot * beam_width +
                                                      beam].item()
@@ -702,10 +718,32 @@ class TRTLLMSampler(Sampler):
                 new_token = new_tokens_host[step][seq_slot][beam]
                 request.add_new_token(new_token, beam)
 
+                if request.py_return_log_probs:
+                    # NOTE: Log probs with drafting has not been tested yet.
+                    begin_log_probs_offset = request.prompt_len if request.sampling_config.beam_width == 1 else 0
+                    current_token = seq_len - request.prompt_len - num_new_tokens[
+                        beam] + step
+                    log_probs.append({
+                        new_token.item():
+                        Logprob(logprob=state.host.log_probs[seq_slot][beam]
+                                [begin_log_probs_offset +
+                                 current_token].item(),
+                                rank=1)
+                    })
+
+            if num_new_tokens[beam] > 0 and request.py_return_log_probs:
+                cum_log_probs.append(
+                    state.host.cum_log_probs[seq_slot * beam_width +
+                                             beam].item())
+
             finish_reason = finish_reasons_host[seq_slot * beam_width +
                                                 beam].item()
             request.set_finished_reason(FinishReason(finish_reason), beam)
 
+        if request.py_return_log_probs:
+            request.py_result.append_log_probs([log_probs], cum_log_probs)
+
         # Set number of tokens predicted per runtime iteration. Will be > 1 for speculative decoding.
         request.update_num_tokens_per_iteration(
             request.max_beam_num_tokens - current_num_of_tokens,
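
To make the log-prob indexing above concrete, a short worked example with hypothetical numbers (single beam, so begin_log_probs_offset equals the prompt length):

prompt_len = 5      # hypothetical request.prompt_len
seq_len = 8         # sequence length reported by the decoder for this beam
num_new = 2         # num_new_tokens[beam] accepted this iteration
begin = prompt_len  # beam_width == 1 path

for step in range(num_new):
    current_token = seq_len - prompt_len - num_new + step
    print(begin + current_token)  # prints 6, then 7

In other words, each new token reads the log probability stored at its absolute position in the beam's buffer (seq_len - num_new + step).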


@@ -70,9 +70,7 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
 
-    if enable_trtllm_sampler and return_log_probs:
-        pytest.skip("TRTLLMSampler does not support return_log_probs")
-    elif not enable_trtllm_sampler and gather_context_logits:
+    if not enable_trtllm_sampler and gather_context_logits:
         pytest.skip("TorchSampler does not support gather_context_logits")
 
     build_config = BuildConfig()
@@ -141,9 +139,7 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
 
-    if enable_trtllm_sampler and return_log_probs:
-        pytest.skip("TRTLLMSampler does not support return_log_probs")
-    elif not enable_trtllm_sampler and gather_context_logits:
+    if not enable_trtllm_sampler and gather_context_logits:
        pytest.skip("TorchSampler does not support gather_context_logits")
 
     build_config = BuildConfig()