[TRTLLM-4995][feat] TRTLLM Sampler log probs support (#4836)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
parent: 00991d1520
commit: fdf1c47d1d
@@ -118,15 +118,24 @@ class LogProbStorage:
         self.log_probs = [[] for _ in range(self.beam_width)]
         self.cum_log_probs = [0 for _ in range(self.beam_width)]
 
-    def append(self, new_probs: list[TokenLogprobs]):
+    def append(self,
+               new_probs: list[TokenLogprobs],
+               cum_log_probs: Optional[list[float]] = None):
+        """
+        new_probs: [beam_width, num_tokens]
+        cum_log_probs: [beam_width]
+        """
         if self.beam_width == -1:
             self._init(new_probs)
 
         assert len(new_probs) == self.beam_width, "Beam width mismatch"
-        for idx, probs in enumerate(new_probs):
-            self.log_probs[idx].extend(probs)
-            self.cum_log_probs[idx] += sum(
-                next(iter(prob.values())).logprob for prob in probs)
+        for beam_idx, probs in enumerate(new_probs):
+            self.log_probs[beam_idx].extend(probs)
+            if cum_log_probs is not None:
+                self.cum_log_probs[beam_idx] = cum_log_probs[beam_idx]
+            else:
+                self.cum_log_probs[beam_idx] += sum(
+                    next(iter(prob.values())).logprob for prob in probs)
 
 
 class PyResult:
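A minimal, self-contained sketch of the accumulation behavior introduced above. MiniLogProbStorage and the namedtuple Logprob are illustrative stand-ins, not the library types: when cum_log_probs is supplied, the stored cumulative value is overwritten with the provided number; otherwise the per-token log probs are summed locally, as before this change.

from collections import namedtuple
from typing import Optional

# Stand-in for the library's Logprob type; only the fields used below.
Logprob = namedtuple("Logprob", ["logprob", "rank"])


class MiniLogProbStorage:
    """Illustrative stand-in mirroring the append() logic in the hunk above."""

    def __init__(self, beam_width: int):
        self.beam_width = beam_width
        self.log_probs = [[] for _ in range(beam_width)]
        self.cum_log_probs = [0.0 for _ in range(beam_width)]

    def append(self, new_probs, cum_log_probs: Optional[list] = None):
        assert len(new_probs) == self.beam_width, "Beam width mismatch"
        for beam_idx, probs in enumerate(new_probs):
            self.log_probs[beam_idx].extend(probs)
            if cum_log_probs is not None:
                # Trust the cumulative value computed upstream (e.g. by the decoder).
                self.cum_log_probs[beam_idx] = cum_log_probs[beam_idx]
            else:
                # Fall back to accumulating the per-token log probs locally.
                self.cum_log_probs[beam_idx] += sum(
                    next(iter(prob.values())).logprob for prob in probs)


storage = MiniLogProbStorage(beam_width=1)
storage.append([[{101: Logprob(-0.25, 1)}, {7: Logprob(-1.10, 1)}]])
print(storage.cum_log_probs)  # [-1.35] (accumulated locally)
storage.append([[{42: Logprob(-0.50, 1)}]], cum_log_probs=[-1.85])
print(storage.cum_log_probs)  # [-1.85] (overwritten by the provided value)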
@@ -157,9 +166,11 @@ class PyResult:
         if self._generation_logits:
             self._generation_logits.append(generation_logits)
 
-    def append_log_probs(self, log_probs: list[TokenLogprobs]):
+    def append_log_probs(self,
+                         log_probs: list[TokenLogprobs],
+                         cum_log_probs: Optional[list[float]] = None):
         if self._log_probs:
-            self._log_probs.append(log_probs)
+            self._log_probs.append(log_probs, cum_log_probs)
 
     @property
     def context_logits(self) -> torch.Tensor | None:
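The forwarding here is guarded: PyResult only holds a log-prob storage when the request asked for log probs, so the call silently does nothing otherwise. A small illustrative sketch of that guard-then-forward pattern (MiniPyResult and _RecorderStub are hypothetical stand-ins, not library classes):

from typing import Optional


class _RecorderStub:
    """Hypothetical stand-in for LogProbStorage; just records what it receives."""

    def __init__(self):
        self.calls = []

    def append(self, log_probs, cum_log_probs=None):
        self.calls.append((log_probs, cum_log_probs))


class MiniPyResult:
    def __init__(self, return_log_probs: bool):
        # Storage exists only when log probs were requested.
        self._log_probs = _RecorderStub() if return_log_probs else None

    def append_log_probs(self, log_probs, cum_log_probs: Optional[list] = None):
        if self._log_probs:
            self._log_probs.append(log_probs, cum_log_probs)


with_probs = MiniPyResult(return_log_probs=True)
with_probs.append_log_probs([[{1: -0.1}]], cum_log_probs=[-0.1])
print(len(with_probs._log_probs.calls))  # 1

without_probs = MiniPyResult(return_log_probs=False)
without_probs.append_log_probs([[{1: -0.1}]])  # no-op, nothing stored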
@@ -250,7 +261,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         super().__init__(
             *args,
             client_id=client_id,
-            return_log_probs=False,
+            return_log_probs=return_log_probs,
             return_context_logits=False,
             return_generation_logits=False,
             stop_words_list=torch.tensor(stop_words_list, dtype=torch.int32)
@@ -460,6 +460,8 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
     finished_sum: torch.Tensor
     finish_reasons: torch.Tensor
     sequence_lengths: torch.Tensor
+    log_probs: torch.Tensor
+    cum_log_probs: torch.Tensor
 
 
 @dataclass(kw_only=True)
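A rough sketch of a host-side state object carrying the two new tensors. HostTensorsSketch is a stand-in, not the real SampleStateTensorsHostTRTLLM; the keyword-only dataclass style mirrors the @dataclass(kw_only=True) decorator visible in the surrounding context, and empty tensors act as placeholders when no request asked for log probs, as in the sampler hunk further down.

from dataclasses import dataclass

import torch


@dataclass(kw_only=True)
class HostTensorsSketch:
    # Illustrative stand-in with the two new fields from the diff above.
    new_tokens: torch.Tensor
    finished_sum: torch.Tensor
    finish_reasons: torch.Tensor
    sequence_lengths: torch.Tensor
    log_probs: torch.Tensor      # empty placeholder unless log probs were requested
    cum_log_probs: torch.Tensor  # likewise


host = HostTensorsSketch(
    new_tokens=torch.zeros(1, 1, dtype=torch.int32),
    finished_sum=torch.zeros(1, dtype=torch.int32),
    finish_reasons=torch.zeros(1, dtype=torch.int32),
    sequence_lengths=torch.ones(1, dtype=torch.int32),
    log_probs=torch.empty(0),
    cum_log_probs=torch.empty(0),
)
print(host.log_probs.numel())  # 0, meaning no log probs gathered for this batch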
@@ -650,12 +652,23 @@ class TRTLLMSampler(Sampler):
         sequence_lengths = self.algs.decoder_state.sequence_lengths.to(
             'cpu', non_blocking=True)
 
+        log_probs = torch.empty([0], dtype=torch.float, device='cpu')
+        cum_log_probs = torch.empty([0], dtype=torch.float, device='cpu')
+        if any(request.py_return_log_probs
+               for request in scheduled_requests.all_requests):
+            log_probs = self.algs.decoder_state.log_probs.to('cpu',
+                                                             non_blocking=True)
+            cum_log_probs = self.algs.decoder_state.cum_log_probs.to(
+                'cpu', non_blocking=True)
+
         device = SampleStateTensors(new_tokens=new_tokens_device_tensor)
 
         host = SampleStateTensorsHostTRTLLM(new_tokens=new_output_tokens,
                                             finished_sum=finished_sum,
                                             finish_reasons=finish_reasons,
-                                            sequence_lengths=sequence_lengths)
+                                            sequence_lengths=sequence_lengths,
+                                            log_probs=log_probs,
+                                            cum_log_probs=cum_log_probs)
 
         sampler_event = torch.cuda.Event()
         sampler_event.record()
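The host tensors above are produced with non_blocking device-to-host copies, and a CUDA event is recorded right after them; consumers are expected to wait on that event before reading the host data. A generic sketch of that copy/record/synchronize pattern, assuming a CUDA device is available (this is not the sampler code itself):

import torch

if torch.cuda.is_available():
    device_probs = torch.randn(4, 8, device='cuda')

    # Enqueue an asynchronous device-to-host copy on the current stream.
    host_probs = device_probs.to('cpu', non_blocking=True)

    # Record an event right after the copies, mirroring sampler_event above.
    sampler_event = torch.cuda.Event()
    sampler_event.record()

    # ... other CPU-side work can proceed here ...

    # Only read the host tensor after the event guarantees the copy finished.
    sampler_event.synchronize()
    print(host_probs.sum())

Note that a device-to-host copy only truly overlaps with other work when the destination lives in pinned memory; either way, waiting on the recorded event is the correct guard before touching the host tensors.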
@@ -691,6 +704,9 @@ class TRTLLMSampler(Sampler):
             current_num_of_tokens = request.max_beam_num_tokens
             num_new_tokens = [0] * beam_width
 
+            log_probs = []
+            cum_log_probs = []
+
             for beam in range(beam_width):
                 seq_len = sequence_lengths_host_data[seq_slot * beam_width +
                                                      beam].item()
@@ -702,10 +718,32 @@
                     new_token = new_tokens_host[step][seq_slot][beam]
                     request.add_new_token(new_token, beam)
 
+                    if request.py_return_log_probs:
+                        # NOTE: Log probs with drafting has not been tested yet.
+                        begin_log_probs_offset = request.prompt_len if request.sampling_config.beam_width == 1 else 0
+                        current_token = seq_len - request.prompt_len - num_new_tokens[
+                            beam] + step
+
+                        log_probs.append({
+                            new_token.item():
+                            Logprob(logprob=state.host.log_probs[seq_slot][beam]
+                                    [begin_log_probs_offset +
+                                     current_token].item(),
+                                    rank=1)
+                        })
+
+                if num_new_tokens[beam] > 0 and request.py_return_log_probs:
+                    cum_log_probs.append(
+                        state.host.cum_log_probs[seq_slot * beam_width +
+                                                 beam].item())
+
                 finish_reason = finish_reasons_host[seq_slot * beam_width +
                                                     beam].item()
                 request.set_finished_reason(FinishReason(finish_reason), beam)
 
+            if request.py_return_log_probs:
+                request.py_result.append_log_probs([log_probs], cum_log_probs)
+
             # Set number of tokens predicted per runtime iteration. Will be > 1 for speculative decoding.
             request.update_num_tokens_per_iteration(
                 request.max_beam_num_tokens - current_num_of_tokens,
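The per-token entries appended above map the sampled token id to a Logprob whose value is read from a flat per-slot log-probs buffer; for beam_width == 1 the generated positions start at prompt_len, which is what begin_log_probs_offset accounts for. A small numeric sketch with hypothetical shapes and values (the real tensor is indexed as state.host.log_probs[seq_slot][beam][...], and Logprob here is a stand-in namedtuple):

from collections import namedtuple

import torch

# Stand-in for the library's Logprob type; only the fields used in the diff.
Logprob = namedtuple("Logprob", ["logprob", "rank"])

prompt_len = 3
# Hypothetical flat per-position log probs for one sequence slot and one beam:
# three prompt positions (unused here) followed by two generated positions.
log_probs_host = torch.tensor([0.0, 0.0, 0.0, -0.2, -1.3])

begin_log_probs_offset = prompt_len  # beam_width == 1 case in the hunk above
current_token = 1                    # second generated token in this window
new_token = 42                       # the sampled token id

entry = {
    new_token:
    Logprob(logprob=log_probs_host[begin_log_probs_offset +
                                   current_token].item(),
            rank=1)
}
print(entry)  # {42: Logprob(logprob=-1.299..., rank=1)}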
@@ -70,9 +70,7 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
 
-    if enable_trtllm_sampler and return_log_probs:
-        pytest.skip("TRTLLMSampler does not support return_log_probs")
-    elif not enable_trtllm_sampler and gather_context_logits:
+    if not enable_trtllm_sampler and gather_context_logits:
         pytest.skip("TorchSampler does not support gather_context_logits")
 
     build_config = BuildConfig()
@@ -141,9 +139,7 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
 
-    if enable_trtllm_sampler and return_log_probs:
-        pytest.skip("TRTLLMSampler does not support return_log_probs")
-    elif not enable_trtllm_sampler and gather_context_logits:
+    if not enable_trtllm_sampler and gather_context_logits:
         pytest.skip("TorchSampler does not support gather_context_logits")
 
     build_config = BuildConfig()
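With log probs now supported by TRTLLMSampler, the only remaining skip in these two tests is the TorchSampler plus gather_context_logits combination. A condensed sketch of the pruned skip matrix (the test name and parametrization here are illustrative, not the real test signatures):

import pytest


@pytest.mark.parametrize("enable_trtllm_sampler", [False, True])
@pytest.mark.parametrize("gather_context_logits", [False, True])
@pytest.mark.parametrize("return_log_probs", [False, True])
def test_skip_matrix(enable_trtllm_sampler, gather_context_logits,
                     return_log_probs):
    # return_log_probs no longer forces a skip for either sampler.
    if not enable_trtllm_sampler and gather_context_logits:
        pytest.skip("TorchSampler does not support gather_context_logits")
    assert True  # placeholder for the actual generate-and-check body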