[None][feat] Return logprobs incrementally in torch backend (#8785)

Signed-off-by: Dong Cao <docao@nvidia.com>
This commit is contained in:
Cao Dong 2025-11-07 10:23:39 +08:00 committed by GitHub
parent 9f8d93f89a
commit b53961e972
2 changed files with 32 additions and 11 deletions

View File

@@ -1,4 +1,4 @@
-from copy import deepcopy
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
@@ -327,7 +327,8 @@ class PyResult:
     @property
     def log_probs(self) -> list[TokenLogprobs] | None:
-        return self._log_probs and self._log_probs.log_probs
+        return self._log_probs and hasattr(
+            self._log_probs, 'log_probs') and self._log_probs.log_probs
 
     @property
     def cum_log_probs(self) -> list[float] | None:
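The guarded property above short-circuits instead of raising when `_log_probs` is unset or does not expose a `log_probs` attribute. A minimal sketch of that pattern, using hypothetical stand-in classes rather than the real `PyResult`/`LogProbStorage`:

```python
# Hypothetical stand-ins; only the guarded-property pattern mirrors the diff.
class _Storage:
    def __init__(self):
        self.log_probs = [[-0.1], [-0.2]]  # one list of logprobs per beam


class _Result:
    def __init__(self, storage=None):
        self._log_probs = storage

    @property
    def log_probs(self):
        # Chained `and` short-circuits: returns a falsy value when storage is
        # missing or lacks `log_probs`, otherwise returns the stored lists.
        return self._log_probs and hasattr(
            self._log_probs, 'log_probs') and self._log_probs.log_probs


assert not _Result().log_probs            # no storage attached
assert not _Result(object()).log_probs    # storage without `log_probs`
assert _Result(_Storage()).log_probs == [[-0.1], [-0.2]]
```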
@@ -589,10 +590,21 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
"""
result, is_final = super().create_serialized_result(
use_fast_logits, mpi_world_rank)
# Performs a deep copy of py_result._log_probs to eliminate race conditions that may occur between IPC communication and the overriding of newly generated log_probs in streaming mode.
if self.streaming and self.py_result.log_probs and self.sampling_config.beam_width <= 1:
py_result = copy(self.py_result)
py_result._log_probs = deepcopy(self.py_result._log_probs)
for log_prob in self.py_result.log_probs:
log_prob.clear()
else:
py_result = self.py_result
return LlmResponse(
request_id=self.py_request_id
if self.is_child else self.parent_request_id,
result=LlmResult(result, self.py_result, is_final),
result=LlmResult(result, py_result, is_final),
client_id=self.py_client_id) if len(result) > 0 else None
@property
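The streaming branch above snapshots the accumulated logprobs for the outgoing response and empties the live per-beam lists, so the next streamed response carries only newly generated entries and cannot race with the worker appending to the same storage. A minimal sketch of that snapshot-and-clear pattern with hypothetical stand-in types (not the real LlmRequest/PyResult):

```python
# Minimal sketch of the snapshot-and-clear pattern; names are illustrative only.
from copy import copy, deepcopy


class _LogProbStorage:
    def __init__(self):
        self.log_probs = [[]]  # one growing list of logprobs per beam


class _PyResult:
    def __init__(self):
        self._log_probs = _LogProbStorage()

    @property
    def log_probs(self):
        return self._log_probs and self._log_probs.log_probs


def snapshot_for_response(live: _PyResult) -> _PyResult:
    """Detach the accumulated logprobs for serialization and reset the live buffers."""
    snapshot = copy(live)                              # shallow copy of the wrapper
    snapshot._log_probs = deepcopy(live._log_probs)    # private copy for the response
    for per_beam in live.log_probs:
        per_beam.clear()                               # next response starts empty
    return snapshot


live = _PyResult()
live.log_probs[0].extend([-0.1, -0.2])
first = snapshot_for_response(live)
live.log_probs[0].append(-0.3)                         # worker keeps appending
second = snapshot_for_response(live)
assert first.log_probs == [[-0.1, -0.2]]
assert second.log_probs == [[-0.3]]                    # only the new entries
```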

View File

@@ -272,6 +272,8 @@ class GenerationResultBase:
         self._done = False
         self.metrics_dict = {}
         self.trace_headers: Optional[dict[str, str]] = None
+        # The torch backend will use the trtllm sampler in beam search mode, but that sampler does not support returning logprobs incrementally.
+        self.use_trtllm_sampler = sampling_params.use_beam_search and sampling_params.best_of > 1
 
         if ray_queue is not None:
             if has_event_loop():
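As a quick illustration of the rule encoded above (beam search with best_of > 1 routes to the trtllm sampler; everything else stays on the torch sampler's incremental path), here is a hedged sketch with a stripped-down stand-in for SamplingParams:

```python
# Stand-in with only the two fields referenced in the diff; not the real class.
from dataclasses import dataclass


@dataclass
class _SamplingParams:
    use_beam_search: bool = False
    best_of: int = 1


def uses_trtllm_sampler(params: _SamplingParams) -> bool:
    # Beam search needs the trtllm sampler, which streams cumulative logprobs.
    return params.use_beam_search and params.best_of > 1


assert uses_trtllm_sampler(_SamplingParams(use_beam_search=True, best_of=4))
assert not uses_trtllm_sampler(_SamplingParams())  # torch sampler: incremental logprobs
```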
@@ -378,20 +380,27 @@ class GenerationResultBase:
             # each streamed response_tensors.log_probs[src_idx]
             # contains a streamwise monotonically growing list of logprobs.
             # so we need to accumulate only the new ones unique to that particular streamed response
-            assert output._last_logprobs_len <= len(
-                response_tensors.log_probs[src_idx]
-            ), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
-                f"{len(response_tensors.log_probs[src_idx])})")
-            output.logprobs += response_tensors.log_probs[src_idx][
-                output._last_logprobs_len:]
+            if self.use_trtllm_sampler:
+                assert output._last_logprobs_len <= len(
+                    response_tensors.log_probs[src_idx]
+                ), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
+                    f"{len(response_tensors.log_probs[src_idx])})")
+                output.logprobs += response_tensors.log_probs[src_idx][
+                    output._last_logprobs_len:]
+            else:
+                output.logprobs += response_tensors.log_probs[src_idx]
             # overcome some WAR in the cpp executor
-            if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
+            if finish_reasons[
+                    src_idx] != tllm.FinishReason.CANCELLED and self.use_trtllm_sampler:
+                # Check if logprobs is a list (not a dict or other structure)
                 if len(output.logprobs) > output.length:
                     # LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
                     # Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
                     output.logprobs = output.logprobs[:output.length]
-                assert len(output.logprobs) == output.length
+                assert len(
+                    output.logprobs
+                ) == output.length, f"logprobs length: {len(output.logprobs)} != output.length: {output.length}"
             if response_tensors.generation_logits is not None:
                 output.generation_logits = response_tensors.generation_logits[
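The branch above is the consumer-side counterpart of the first file's change: with the trtllm sampler (beam search) each streamed response still carries the full, monotonically growing logprob list, so only the suffix past `_last_logprobs_len` is new; with the torch sampler the response now contains only the incremental logprobs and is appended wholesale. A hedged sketch of that accumulation logic with simplified, hypothetical names:

```python
# Simplified stand-ins; only the branching logic mirrors the diff above.
from typing import List


def accumulate_logprobs(accumulated: List[float],
                        streamed: List[float],
                        last_logprobs_len: int,
                        use_trtllm_sampler: bool) -> int:
    """Append the newly generated logprobs and return the updated consumed length."""
    if use_trtllm_sampler:
        # trtllm sampler: `streamed` is cumulative, so only the unseen suffix is new.
        assert last_logprobs_len <= len(streamed)
        accumulated.extend(streamed[last_logprobs_len:])
        return len(streamed)
    # torch sampler: `streamed` already contains only the new logprobs.
    accumulated.extend(streamed)
    return last_logprobs_len + len(streamed)


# Incremental (torch sampler) path: each chunk carries only new entries.
logprobs: List[float] = []
seen = accumulate_logprobs(logprobs, [-0.1, -0.2], 0, use_trtllm_sampler=False)
seen = accumulate_logprobs(logprobs, [-0.3], seen, use_trtllm_sampler=False)
assert logprobs == [-0.1, -0.2, -0.3]

# Cumulative (trtllm sampler) path: each chunk repeats everything seen so far.
logprobs, seen = [], 0
seen = accumulate_logprobs(logprobs, [-0.1, -0.2], seen, use_trtllm_sampler=True)
seen = accumulate_logprobs(logprobs, [-0.1, -0.2, -0.3], seen, use_trtllm_sampler=True)
assert logprobs == [-0.1, -0.2, -0.3]
```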