from dataclasses import dataclass, field
from queue import Queue
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union
from weakref import WeakMethod

import torch

from ..bindings import executor as tllm
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue
from ..sampling_params import SamplingParams
from .utils import has_event_loop

if TYPE_CHECKING:
    from .executor import GenerationExecutor
    from .postproc_worker import PostprocWorker
    from .request import GenerationRequest

__all__ = [
    "GenerationResultBase",
    "DetokenizedGenerationResultBase",
    "GenerationResult",
]


@dataclass(slots=True)
class CompletionOutput:
    """The output data of one completion output of a request.

    Args:
        index (int): The index of the output in the request.
        text (str): The generated output text. Defaults to "".
        token_ids (List[int]): The token ids of the generated output text. Defaults to [].
        cumulative_logprob (float, optional): The cumulative log probability of the generated output text. Defaults to None.
        logprobs (List[float]): The log probabilities of the top probability words at each position if the logprobs are requested. Defaults to [].
        finish_reason (Literal['stop', 'length'], optional): The reason why the sequence is finished. Defaults to None.
        stop_reason (int, str, optional): The stop string or token id that caused the completion to stop, None if the completion finished for some other reason. Defaults to None.
        generation_logits (torch.Tensor, optional): The logits on the generated output token ids. Defaults to None.

    Properties:
        length (int): The number of generated tokens.
        token_ids_diff (List[int]): Newly generated token ids.
        logprobs_diff (List[float]): Logprobs of newly generated tokens.
        text_diff (str): Newly generated text.
    """
    index: int
    text: str = ""
    token_ids: List[int] = field(default_factory=list)
    cumulative_logprob: Optional[float] = None
    logprobs: List[float] = field(default_factory=list)
    finish_reason: Optional[Literal['stop', 'length']] = None
    stop_reason: Optional[Union[int, str]] = None
    generation_logits: Optional[torch.Tensor] = None

    # hidden fields for tracking the diffs
    _last_text_len: int = field(default=0, init=False, repr=False)
    _last_token_ids_len: int = field(default=0, init=False, repr=False)
    _last_logprobs_len: int = field(default=0, init=False, repr=False)
    _incremental_states: Optional[dict] = field(default=None,
                                                init=False,
                                                repr=False)
    # the result of result_handler passed to postprocess workers
    _postprocess_result: Any = None

    @property
    def length(self):
        return len(self.token_ids)

    @property
    def text_diff(self) -> str:
        return self.text[self._last_text_len:]

    @property
    def token_ids_diff(self) -> List[int]:
        return self.token_ids[self._last_token_ids_len:]

    @property
    def logprobs_diff(self) -> List[float]:
        return self.logprobs[self._last_logprobs_len:]
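
    # NOTE: The ``*_diff`` properties expose only what the most recent response
    # added: the handlers record the previous lengths in ``_last_*_len`` before
    # extending ``token_ids``/``logprobs``/``text``. A streaming consumer can
    # therefore read just the increment, e.g. (illustrative sketch;
    # ``generation_result`` is a ``GenerationResult`` as defined below, and
    # ``text_diff`` is only populated when detokenization is performed):
    #
    #     async for step in generation_result:
    #         print(step.outputs[0].text_diff, end="", flush=True)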


class GenerationResultBase:
    ''' This holds the core logic of the GenerationResult class. '''

    def __init__(self,
                 id: int,
                 sampling_params: SamplingParams,
                 background_error_handler: Optional[Callable] = None):
        self.id = id
        self.sampling_params = sampling_params
        self._done = False
        self._cancelled = False

        if has_event_loop():
            self.aqueue = AsyncQueue()
            self.queue = self.aqueue.sync_q
        else:
            self.queue = Queue()
            self.aqueue = None

        # In sampling mode, the Executor runtime returns best_of sequences in
        # total, from which the LLM API selects the n best based on their
        # cumulative log probabilities.
        self._outputs: List[CompletionOutput] = [
            CompletionOutput(i) for i in range(self.sampling_params.best_of)
        ]
        self.context_logits: Optional[torch.Tensor] = None

        self._background_error_handler = None
        if background_error_handler is not None:
            if not isinstance(background_error_handler, WeakMethod):
                self._background_error_handler = WeakMethod(
                    background_error_handler)
            else:
                self._background_error_handler = background_error_handler

        # Avoid transmitting the sampling_params for a request more than once.
        # SamplingParams is necessary for creating dummy GenerationResultBase
        # instances on postprocess worker processes.
        self._postproc_sampling_params_transmitted = False

    @property
    def outputs(self) -> List[CompletionOutput]:
        sampling_param = self.sampling_params
        if (sampling_param.use_beam_search
                or sampling_param.n == sampling_param.best_of):
            return self._outputs[:sampling_param.n]
        # Pick the top-n outputs, sorted by cumulative log probs.
        sorted_outputs = sorted(
            self._outputs,
            key=lambda x:
            (x.cumulative_logprob
             if x.cumulative_logprob is not None else float('-inf')),
            reverse=True)
        # Reindex the sequence.
        for i, sorted_out in enumerate(sorted_outputs):
            sorted_out.index = i
        return sorted_outputs[:sampling_param.n]
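
    # Example with hypothetical numbers: with ``best_of=4`` and ``n=2`` in
    # sampling mode, four sequences are accumulated in ``self._outputs`` and
    # the two with the highest cumulative log probabilities are returned,
    # re-indexed as 0 and 1. With beam search or ``n == best_of``, the first
    # ``n`` entries are returned unchanged.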

    def handle_sequence(self, response: "GenerationExecutor.Response",
                        sequence_index: int):
        """ Handle a single sequence in the response. """

        tensors = response.tensors
        assert tensors is not None

        beam_search = self.sampling_params.use_beam_search
        seq_idx = sequence_index
        src_idx = sequence_index if beam_search else 0

        output = self._outputs[seq_idx]

        output._last_token_ids_len = len(output.token_ids)
        output.token_ids.extend(tensors.output_token_ids[src_idx])
        if tensors.cum_log_probs is not None:
            output.cumulative_logprob = tensors.cum_log_probs[src_idx]
        if tensors.log_probs is not None:
            output._last_logprobs_len = len(output.logprobs)
            output.logprobs = tensors.log_probs[src_idx]
            assert len(output.logprobs) == output.length
        if tensors.generation_logits is not None:
            output.generation_logits = tensors.generation_logits[
                src_idx, :output.length]

        if self._done:
            if response.finish_reasons[src_idx] == tllm.FinishReason.END_ID:
                output.finish_reason = 'stop'
            elif response.finish_reasons[
                    src_idx] == tllm.FinishReason.STOP_WORDS:
                output.finish_reason = 'stop'
                for stop_reason, stop_ids in self.sampling_params._get_stop_reasons_and_words(
                ):
                    if output.token_ids[-len(stop_ids):] == stop_ids:
                        output.stop_reason = stop_reason
                        if not self.sampling_params.include_stop_str_in_output:
                            output.token_ids = output.token_ids[:-len(stop_ids)]
                        break
            elif response.finish_reasons[src_idx] == tllm.FinishReason.LENGTH:
                output.finish_reason = 'length'
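
    # ``handle_response`` accepts either a raw ``GenerationExecutor.Response``
    # (tensors plus finish reasons, dispatched to ``handle_sequence`` once per
    # beam, or once for the reported sequence index in sampling mode) or an
    # already post-processed ``PostprocWorker.Output``, in which case only the
    # final flag and the post-processed payload are recorded here.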

    def handle_response(self, response: Union["GenerationExecutor.Response",
                                               "PostprocWorker.Output"]):

        is_postprocess_res = isinstance(response, PostprocWorker.Output)
        if is_postprocess_res:
            self._done = response.is_final
            if isinstance(response.res, CompletionOutput):
                # in streaming mode
                self._outputs[0] = response.res
            else:
                self._outputs[0]._postprocess_result = response.res

        self._done = response.is_final

        if response.error:
            if self._background_error_handler is not None and (
                    handler := self._background_error_handler()):
                handler(response.error)

        if is_postprocess_res:
            return

        tensors = response.tensors

        # output_token_ids = (beams, tokens)
        if self.sampling_params.use_beam_search:
            for beam_idx, _ in enumerate(tensors.output_token_ids):
                self.handle_sequence(response, beam_idx)
        else:
            self.handle_sequence(response, response.sequence_index)

        if tensors.context_logits is not None:
            self.context_logits = tensors.context_logits

        # Process background errors here ASAP during generation.
        if self._background_error_handler and (
                handler := self._background_error_handler()):
            handler()

    @property
    def done(self) -> bool:
        return self._done


class DetokenizedGenerationResultBase(GenerationResultBase):
    ''' The base class for the generation result with detokenization support. '''
    # import once and avoid cyclic import
    from .postproc_worker import PostprocWorker

    def __init__(self,
                 id: int,
                 sampling_params: SamplingParams,
                 tokenizer: Optional[Callable] = None,
                 streaming: bool = False,
                 background_error_handler: Optional[Callable] = None):
        super().__init__(id, sampling_params, background_error_handler)
        self.tokenizer = tokenizer
        self._streaming = streaming

    def handle_response(self, response: "GenerationExecutor.Response"):
        GenerationResultBase.handle_response(self, response)

        # The postprocess has been performed, return directly
        if isinstance(response, PostprocWorker.Output):
            return

        kwargs = {
            'skip_special_tokens':
            self.sampling_params.skip_special_tokens,
            'spaces_between_special_tokens':
            self.sampling_params.spaces_between_special_tokens
        }
        if self.sampling_params.detokenize and self.tokenizer is not None:
            for beam_output in self.outputs:
                beam_output._last_text_len = len(beam_output.text)
                if hasattr(self.tokenizer, 'decode_incrementally'):
                    if self._streaming and not self.sampling_params.use_beam_search:
                        beam_output.text, beam_output._incremental_states = self.tokenizer.decode_incrementally(
                            beam_output.token_ids_diff,
                            prev_text=beam_output.text,
                            states=beam_output._incremental_states,
                            flush=self._done,
                            **kwargs)
                    else:
                        beam_output.text, _ = self.tokenizer.decode_incrementally(
                            beam_output.token_ids, flush=self._done, **kwargs)
                else:
                    beam_output.text = self.tokenizer.decode(
                        beam_output.token_ids, **kwargs)
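
    # Design note: when the tokenizer provides ``decode_incrementally``,
    # streaming (non-beam-search) requests decode only ``token_ids_diff`` and
    # carry the tokenizer state across steps instead of re-decoding the full
    # sequence on every response; beam search and non-streaming requests fall
    # back to decoding the complete ``token_ids`` each time.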


# alias
PostprocWorker = DetokenizedGenerationResultBase.PostprocWorker


class GenerationResult(GenerationResultBase):
    '''
    The result of a generation request. It can be used to wait for the completion of the request.

    Args:
        generation_request (GenerationRequest): The generation request object.
        background_error_handler (Callable, optional): The error handler to process the errors from the background threads/processes. Defaults to None.
    '''

    def __init__(self,
                 generation_request: "GenerationRequest",
                 background_error_handler: Optional[Callable] = None) -> None:
        super().__init__(generation_request.id,
                         generation_request.sampling_params,
                         background_error_handler)
        self._generation_request = generation_request

    @property
    def request_id(self) -> int:
        return self._generation_request.id

    @property
    def prompt_token_ids(self) -> List[int]:
        return self._generation_request.prompt_token_ids

    @property
    def finished(self) -> bool:
        return self._done

    @property
    def streaming(self):
        return self._generation_request.streaming

    @property
    def generation_request(self) -> "GenerationRequest":
        return self._generation_request

    def result_step(self, timeout: Optional[float] = None):
        response = self.queue.get(timeout=timeout)
        self.handle_response(response)

    async def aresult_step(self):
        assert self.aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available."
        response = await self.aqueue.get()
        global_tracer().log_instant("result_step.get")
        self.handle_response(response)

    def result(self, timeout: Optional[float] = None) -> "GenerationResult":
        while not self._done:
            self.result_step(timeout)
        return self

    async def aresult(self) -> "GenerationResult":
        while not self._done:
            await self.aresult_step()
        return self

    def __await__(self):
        return self.aresult().__await__()

    def __iter__(self):
        return self

    def __next__(self):
        if self._done:
            raise StopIteration

        self.result_step()
        return self

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._done:
            raise StopAsyncIteration

        await self.aresult_step()
        return self

    def running(self) -> bool:
        return not self._done

    def cancelled(self) -> bool:
        return self._cancelled

    def cancel(self):
        raise NotImplementedError

    def exception(self, timeout: Optional[float] = None):
        try:
            self.result(timeout)
        except RuntimeError as e:
            return e

    def _repr_fields(self):
        return [
            'request_id', 'prompt_token_ids', 'outputs', 'finished',
            'context_logits'
        ]

    def __repr__(self) -> str:
        repr = []
        for field in self._repr_fields():
            value = getattr(self, field)
            if isinstance(value, str):
                repr.append(f"{field}={value!r}")
            else:
                repr.append(f"{field}={value}")
        repr = ", ".join(repr)
        repr = f"{self.__class__.__name__}({repr})"
        return repr

    def __hash__(self):
        return hash(self.request_id)
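

# A minimal usage sketch (hypothetical wiring; ``executor.submit`` stands in
# for whatever component enqueues ``GenerationExecutor.Response`` objects into
# ``result.queue``):
#
#     result = executor.submit(request)         # -> GenerationResult
#
#     # Synchronous: drain responses until the request finishes.
#     for partial in result:                    # __next__() -> result_step()
#         new_tokens = partial.outputs[0].token_ids_diff
#
#     # Asynchronous: requires an event loop at construction time so that the
#     # AsyncQueue path is used.
#     finished = await result                   # __await__() -> aresult()
#     print(finished.outputs[0].token_ids)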