Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] Return logprobs incrementally in torch backend (#8785)
Signed-off-by: Dong Cao <docao@nvidia.com>
parent 9f8d93f89a
commit b53961e972
@@ -1,4 +1,4 @@
-from copy import deepcopy
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union

@@ -327,7 +327,8 @@ class PyResult:

     @property
     def log_probs(self) -> list[TokenLogprobs] | None:
-        return self._log_probs and self._log_probs.log_probs
+        return self._log_probs and hasattr(
+            self._log_probs, 'log_probs') and self._log_probs.log_probs

     @property
     def cum_log_probs(self) -> list[float] | None:
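A minimal sketch of how the guarded property above behaves, using hypothetical _LogProbStorage and _PyResultSketch stand-ins rather than the real TensorRT-LLM classes: the hasattr() check keeps the property safe when _log_probs is unset or is an object without a log_probs attribute.

# Illustrative stand-ins only, not the actual TensorRT-LLM implementations.
from dataclasses import dataclass, field


@dataclass
class _LogProbStorage:
    # Per-beam lists of token logprobs.
    log_probs: list = field(default_factory=list)


class _PyResultSketch:

    def __init__(self, storage=None):
        self._log_probs = storage

    @property
    def log_probs(self):
        # Short-circuits to a falsy value when storage is missing, and the
        # hasattr() guard keeps it safe if _log_probs lacks a .log_probs field.
        return self._log_probs and hasattr(
            self._log_probs, 'log_probs') and self._log_probs.log_probs


print(_PyResultSketch().log_probs)                            # None
print(_PyResultSketch(_LogProbStorage([[-0.1]])).log_probs)   # [[-0.1]]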
@@ -589,10 +590,21 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         """
         result, is_final = super().create_serialized_result(
             use_fast_logits, mpi_world_rank)

+        # Performs a deep copy of py_result._log_probs to eliminate race conditions that may occur between IPC communication and the overriding of newly generated log_probs in streaming mode.
+        if self.streaming and self.py_result.log_probs and self.sampling_config.beam_width <= 1:
+            py_result = copy(self.py_result)
+            py_result._log_probs = deepcopy(self.py_result._log_probs)
+
+            for log_prob in self.py_result.log_probs:
+                log_prob.clear()
+        else:
+            py_result = self.py_result
+
         return LlmResponse(
             request_id=self.py_request_id
             if self.is_child else self.parent_request_id,
-            result=LlmResult(result, self.py_result, is_final),
+            result=LlmResult(result, py_result, is_final),
             client_id=self.py_client_id) if len(result) > 0 else None

     @property
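The added block above snapshots the logprobs before handing them to the response and only then clears the live buffer, so the streamed response cannot race with logprobs generated afterwards. Below is a self-contained sketch of that snapshot-then-clear pattern, assuming hypothetical _Result and snapshot_for_response names rather than the real LlmRequest/LlmResponse machinery.

# Standalone sketch of the snapshot-then-clear pattern used for streaming
# logprobs above; names here are illustrative, not the real classes.
from copy import copy, deepcopy


class _Result:

    def __init__(self):
        # One growing list of logprobs per beam (beam_width == 1 here).
        self._log_probs = [[]]

    @property
    def log_probs(self):
        return self._log_probs


def snapshot_for_response(live: _Result) -> _Result:
    # Shallow-copy the result wrapper, then deep-copy the logprob storage so
    # the response owns its own lists ...
    snap = copy(live)
    snap._log_probs = deepcopy(live._log_probs)
    # ... and clear the live lists so the next streamed chunk starts empty
    # instead of re-sending (or racing with) what was already shipped.
    for per_beam in live.log_probs:
        per_beam.clear()
    return snap


live = _Result()
live._log_probs[0].extend([-0.11, -0.42])   # tokens generated so far
resp = snapshot_for_response(live)
live._log_probs[0].append(-0.07)            # next token arrives later

print(resp.log_probs)   # [[-0.11, -0.42]] -- unaffected by later appends
print(live.log_probs)   # [[-0.07]]        -- only the new, unsent logprobs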
@@ -272,6 +272,8 @@ class GenerationResultBase:
         self._done = False
         self.metrics_dict = {}
         self.trace_headers: Optional[dict[str, str]] = None
+        # The torch backend uses the trtllm sampler in beam search mode, but that sampler does not support returning logprobs incrementally.
+        self.use_trtllm_sampler = sampling_params.use_beam_search and sampling_params.best_of > 1

         if ray_queue is not None:
             if has_event_loop():
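For context, the new use_trtllm_sampler flag is derived directly from the sampling parameters. The snippet below mirrors that condition with a hypothetical _Params stand-in for SamplingParams; it is an illustration, not the library API.

# Sketch of the condition behind use_trtllm_sampler; _Params is a stand-in.
from dataclasses import dataclass


@dataclass
class _Params:
    use_beam_search: bool = False
    best_of: int = 1


def uses_trtllm_sampler(p: _Params) -> bool:
    # Beam search with more than one candidate falls back to the trtllm
    # sampler, which still reports cumulative (non-incremental) logprobs.
    return p.use_beam_search and p.best_of > 1


print(uses_trtllm_sampler(_Params()))                                 # False
print(uses_trtllm_sampler(_Params(use_beam_search=True, best_of=4)))  # True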
@@ -378,20 +380,27 @@ class GenerationResultBase:
                 # each streamed response_tensors.log_probs[src_idx]
                 # contains a streamwise monotonically growing list of logprobs.
                 # so we need to accumulate only the new ones unique to that particular streamed response
-                assert output._last_logprobs_len <= len(
-                    response_tensors.log_probs[src_idx]
-                ), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
-                    f"{len(response_tensors.log_probs[src_idx])})")
-                output.logprobs += response_tensors.log_probs[src_idx][
-                    output._last_logprobs_len:]
+                if self.use_trtllm_sampler:
+                    assert output._last_logprobs_len <= len(
+                        response_tensors.log_probs[src_idx]
+                    ), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
+                        f"{len(response_tensors.log_probs[src_idx])})")
+                    output.logprobs += response_tensors.log_probs[src_idx][
+                        output._last_logprobs_len:]
+                else:
+                    output.logprobs += response_tensors.log_probs[src_idx]

             # overcome some WAR in the cpp executor
-            if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
+            if finish_reasons[
+                    src_idx] != tllm.FinishReason.CANCELLED and self.use_trtllm_sampler:
                 # Check if logprobs is a list (not a dict or other structure)
                 if len(output.logprobs) > output.length:
                     # LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
                     # Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
                     output.logprobs = output.logprobs[:output.length]
-                assert len(output.logprobs) == output.length
+                assert len(
+                    output.logprobs
+                ) == output.length, f"logprobs length: {len(output.logprobs)} != output.length: {output.length}"

         if response_tensors.generation_logits is not None:
             output.generation_logits = response_tensors.generation_logits[
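The hunk above keeps two accumulation paths: the trtllm-sampler path still receives a cumulative, monotonically growing list and takes only the tail past _last_logprobs_len, while the torch path now receives only the new logprobs and appends them directly. Here is a small sketch of that contrast, using a hypothetical _Output holder and accumulate() helper rather than the real GenerationResultBase code.

# Sketch contrasting the two accumulation paths; _Output is a stand-in for
# the per-sequence output object, not the real TensorRT-LLM class.
class _Output:

    def __init__(self):
        self.logprobs: list[float] = []
        self._last_logprobs_len = 0


def accumulate(output: _Output, chunk: list[float], use_trtllm_sampler: bool):
    if use_trtllm_sampler:
        # trtllm sampler path: `chunk` is the whole, monotonically growing
        # list, so only the entries past what was already consumed are new.
        assert output._last_logprobs_len <= len(chunk)
        output.logprobs += chunk[output._last_logprobs_len:]
    else:
        # Incremental path: `chunk` already contains only the new logprobs.
        output.logprobs += chunk
    output._last_logprobs_len = len(output.logprobs)


out = _Output()
accumulate(out, [-0.5], use_trtllm_sampler=True)         # cumulative chunk
accumulate(out, [-0.5, -0.2], use_trtllm_sampler=True)   # grown cumulative chunk
print(out.logprobs)   # [-0.5, -0.2]

out2 = _Output()
accumulate(out2, [-0.5], use_trtllm_sampler=False)       # incremental chunks
accumulate(out2, [-0.2], use_trtllm_sampler=False)
print(out2.logprobs)  # [-0.5, -0.2]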