Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-06 19:21:52 +08:00)
[TRTLLM-7155][feat] Unify sampler handle logits implementation. (#6867)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
parent 983dd7e57c
commit 099f081e03
@@ -16,7 +16,6 @@ from tensorrt_llm.mapping import CpType
from ..distributed import Distributed
from .llm_request import (ExecutorRequest, LlmRequest,
                          executor_request_to_llm_request)
from .sampler import Sampler, TorchSampler

SHUTDOWN_REQUEST_ID = -1

@@ -707,21 +706,19 @@ class ExecutorRequestQueue:

    def set_exclude_last_generation_logits(self,
                                           disable_overlap_scheduler: bool,
                                           sampler: Sampler) -> None:
                                           pp_size: int) -> None:
        # When overlap scheduler is enabled then when starting to handle a new prompt,
        # sample_async is called twice before the first call to update_requests:
        # - 1st time as a context request that handles on the 1st generated token
        # - 2nd time as a generation request that handles on the 2nd generated token.
        # and only after these two calls the sampler's update_request method is called.
        # So in a sampler that works by the expected flow of handling the logits in
        # sample_async (TorchSampler is an anomaly that instead does that on
        # update_requests), every update_request doesn't handle the newest token, but one
        # sample_async, every update_request doesn't handle the newest token, but one
        # before it. Since all these calls work on the same request object, then its
        # logits storage contains the logits of both the token update_requests should work
        # on, and also its next token. Thus, excluding the last generation logits from any
        # getter is required, when not using TorchSampler.
        self.should_exclude_last_generation_logits = not disable_overlap_scheduler and not isinstance(
            sampler, TorchSampler)
        # getter is required.
        self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1

    def _should_exclude_last_generation_logits(self) -> bool:
        return self.should_exclude_last_generation_logits
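Note: the long comment above collapses to a single predicate. A minimal sketch follows; the standalone helper is hypothetical (the real code assigns the flag inline as shown in the hunk):

# Hypothetical helper mirroring the new condition assigned above. The old
# condition depended on the sampler type; the new one depends on pipeline
# parallelism instead.
def should_exclude_last_generation_logits(disable_overlap_scheduler: bool,
                                          pp_size: int) -> bool:
    # With the overlap scheduler active and no pipeline parallelism, the
    # request's logits storage is always one generation step ahead of
    # update_requests, so getters must drop the trailing entry.
    return not disable_overlap_scheduler and pp_size == 1

assert should_exclude_last_generation_logits(False, 1) is True   # overlap on, no PP
assert should_exclude_last_generation_logits(True, 1) is False   # overlap scheduler disabled
assert should_exclude_last_generation_logits(False, 2) is False  # pipeline parallel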
@@ -1,3 +1,4 @@
from itertools import chain
from typing import List

import torch
@@ -16,9 +17,9 @@ class HandleLogits:
        context_requests: List[LlmRequest],
        generation_requests: List[LlmRequest],
        logits: torch.Tensor,
        num_context_logits_prefix_sum: List[int],
        max_num_sequences: int,
        beam_width: int,
        num_context_logits_prefix_sum: list[int],
        is_generation_model: bool,
    ):
        """Handles context and generation logits for a batch of requests.

@@ -26,10 +27,24 @@ class HandleLogits:
            context_requests: List of context requests to process
            generation_requests: List of generation requests to process
            logits: Input logits tensor
            num_context_logits_prefix_sum: Prefix sum of context logits for each request
            max_num_sequences: Maximum number of sequences to process
            beam_width: Beam width for the generation requests
            num_context_logits_prefix_sum: Prefix sum of the logits
            is_generation_model: Bool containing whether the model is generation or not
        """
        if not any(r.py_return_context_logits or r.py_return_generation_logits
                   for r in chain(context_requests, generation_requests)):
            return

        if not is_generation_model:
            for llm_req, logits_temp in zip(context_requests, logits):
                if logits_temp.ndim == 1:
                    # For BERT: Add axis to be compatible with LogitsStorage
                    # (LogitsStorage will interpret this dim as the prompt_len which
                    # is not relevant for outputting logits of encoder only model).
                    logits_temp = logits_temp.unsqueeze(0)
                llm_req.py_result.append_context_logits(logits_temp)
            return

        # Copy logits into decoderBuffers.logits
        for batch_index, llm_req in enumerate(context_requests):
            logits_begin = num_context_logits_prefix_sum[batch_index]
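Note: a small self-contained sketch of how the prefix sum consumed by HandleLogits is built and then used to slice the flattened logits tensor per context request. FakeRequest, the chunk sizes, and the vocab size are stand-ins for illustration, not the real LlmRequest API:

# Minimal sketch, not the library implementation: build
# num_context_logits_prefix_sum the way the callers in this commit do, then
# slice a flat [num_tokens, vocab] logits tensor per context request.
import torch

class FakeRequest:  # stand-in for LlmRequest
    def __init__(self, context_chunk_size, return_context_logits):
        self.context_chunk_size = context_chunk_size
        self.py_return_context_logits = return_context_logits

context_requests = [FakeRequest(4, True), FakeRequest(3, False), FakeRequest(5, True)]

num_context_logits_prefix_sum = [0]
prefix_sum = 0
for request in context_requests:
    # a request that does not return context logits contributes only its last-token logits
    prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
    num_context_logits_prefix_sum.append(prefix_sum)

print(num_context_logits_prefix_sum)  # [0, 4, 5, 10]

logits = torch.randn(num_context_logits_prefix_sum[-1], 32000)  # flat context logits
for i, request in enumerate(context_requests):
    if not request.py_return_context_logits:
        continue
    begin = num_context_logits_prefix_sum[i]
    end = num_context_logits_prefix_sum[i + 1]
    per_request_logits = logits[begin:end]  # rows belonging to this request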
@@ -39,6 +39,7 @@ from ..models.modeling_utils import DecoderModelForCausalLM
from ..speculative.drafter import Drafter
from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
from .guided_decoder import GuidedDecoder
from .handle_logits import HandleLogits
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
                          LlmResponse)
@@ -244,7 +245,7 @@ class PyExecutor:
            is_disaggregated=kv_cache_transceiver is not None,
        )
        self.executor_request_queue.set_exclude_last_generation_logits(
            self.disable_overlap_scheduler, self.sampler)
            self.disable_overlap_scheduler, self.dist.pp_size)

        self.stats_lock = threading.Lock()
        self.stats = []
@@ -681,24 +682,6 @@ class PyExecutor:
            self.response_cv.notify_all()
        self.shutdown_event.set()

    def _need_return_logits(self, scheduled_requests: ScheduledRequests):
        for req in scheduled_requests.context_requests:
            if req.py_return_context_logits:
                return True
        for req in scheduled_requests.generation_requests:
            if req.py_return_generation_logits:
                return True
        return False

    def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
        for req in scheduled_requests.context_requests:
            if req.py_return_log_probs:
                return True
        for req in scheduled_requests.generation_requests:
            if req.py_return_log_probs:
                return True
        return False

    def _executor_loop_pp(self):
        logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
        torch.cuda.set_device(self.device_id)
@@ -790,10 +773,6 @@
                else:
                    with torch.cuda.nvtx.range("_forward_step_last_pp"):
                        batch_outputs = self._forward_step(scheduled_batch)
                        logits_host = None
                        if self._need_return_logits(scheduled_batch):
                            logits_host = batch_outputs["logits"].to(
                                "cpu", non_blocking=True)
                    if self.kv_cache_transceiver and self.guided_decoder:
                        self.guided_decoder.init_disagg_gen_requests(
                            scheduled_batch)
@@ -802,7 +781,6 @@

                    sample_state = self._sample_async(
                        scheduled_batch, batch_outputs)
                    sample_state.host.logits = logits_host
                    self._update_request_states(scheduled_batch)

                    if self.enable_iter_perf_stats:
@@ -832,18 +810,10 @@
                    torch.cuda.nvtx.range_push(
                        "_handle_new_tokens_inter_pp")
                    # Receive tokens from previous pp rank (w.r.t model forward direction)
                    (
                        logits,
                        sample_state.host,
                    ) = self.dist.recv_object(
                    sample_state.host = self.dist.recv_object(
                        src=self.dist.prev_pp_rank,
                        tag=prev_microbatch_id,
                    )
                    if logits is not None:
                        logits_host = torch.from_numpy(logits)
                        sample_state.host.logits = logits_host
                        sample_state.device.logits = logits_host.to(
                            self.device_id)
                else:
                    torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp")
                    sample_state.sampler_event.synchronize()
@@ -853,18 +823,9 @@
                if not self.dist.is_second_last_pp_rank:
                    if self.send_handles[prev_microbatch_id] is not None:
                        self.send_handles[prev_microbatch_id].wait()
                    needs_logits = (
                        self._need_return_logits(scheduled_batch)
                        or (self._need_return_log_probs(scheduled_batch)
                            and sample_state.host.log_probs is not None))
                    serialized_logits = sample_state.host.logits.numpy(
                    ) if needs_logits else None
                    self.send_handles[
                        prev_microbatch_id] = self.dist.isend_object(
                            (
                                serialized_logits,
                                sample_state.host,
                            ),
                            sample_state.host,
                            dest=self.dist.next_pp_rank,
                            tag=prev_microbatch_id)
                torch.cuda.nvtx.range_pop()
@@ -884,6 +845,40 @@
                            previous_batch.scheduled_ctx_reqs)

                    self._handle_canceled_requests()

                    # If logits were requested last PP rank has to send to first PP rank (who sends responses) the
                    # logits of the requests that have finished.
                    # NOTE: If the rank processing the logits ever becomes the same as
                    # the rank sending the responses, this code can be removed.
                    finished_reqs = [
                        r for r in previous_batch.sample_state.
                        scheduled_requests.all_requests()
                        if r.state == LlmRequestState.GENERATION_COMPLETE
                        and (r.py_return_context_logits
                             or r.py_return_generation_logits)
                    ]
                    if self.dist.is_first_pp_rank and len(finished_reqs):
                        finished_reqs_py_results = [
                            r.py_result for r in finished_reqs
                        ]
                        finished_reqs_py_results = self.dist.recv_object(
                            src=self.dist.prev_pp_rank,
                            tag=prev_microbatch_id,
                        )
                        for req, py_result in zip(finished_reqs,
                                                  finished_reqs_py_results):
                            req.py_result = py_result

                    elif self.dist.is_last_pp_rank and len(finished_reqs):
                        if self.send_handles[
                                prev_microbatch_id] is not None:
                            self.send_handles[prev_microbatch_id].wait()
                        self.send_handles[
                            prev_microbatch_id] = self.dist.isend_object(
                                [r.py_result for r in finished_reqs],
                                dest=self.dist.next_pp_rank,
                                tag=prev_microbatch_id)

                    finished_requests = self._handle_responses()
                    previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
                    self.resource_manager.update_resources(
@@ -1538,7 +1533,22 @@ class PyExecutor:
                      batch_outputs) -> SampleState | None:
        try:
            if batch_outputs is not None:
                return self.sampler.sample_async(scheduled_batch, batch_outputs)
                num_context_logits_prefix_sum = [0]
                prefix_sum = 0
                for request in scheduled_batch.context_requests:
                    prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
                    num_context_logits_prefix_sum.append(prefix_sum)

                HandleLogits()(scheduled_batch.context_requests,
                               scheduled_batch.generation_requests,
                               batch_outputs["logits"],
                               self.sampler.beam_width(
                                   scheduled_batch.all_requests()),
                               num_context_logits_prefix_sum,
                               self.sampler.is_generation_model())

                return self.sampler.sample_async(scheduled_batch, batch_outputs,
                                                 num_context_logits_prefix_sum)
        except Exception as e:
            traceback.print_exc()
            error_msg = str(e)
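Note: because removed and added lines are interleaved above without +/- markers, here is a hedged reconstruction of the updated PyExecutor._sample_async flow assembled from the added lines of the hunk. The indentation, the abbreviated signature, and the exact placement of the batch_outputs guard are assumptions, not copied from the repository:

    # Hedged reconstruction, assembled from the hunk above.
    def _sample_async(self, scheduled_batch, batch_outputs) -> SampleState | None:
        try:
            if batch_outputs is not None:
                # Build one prefix-sum over the context requests' logits rows.
                num_context_logits_prefix_sum = [0]
                prefix_sum = 0
                for request in scheduled_batch.context_requests:
                    prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
                    num_context_logits_prefix_sum.append(prefix_sum)

                # Context/generation logits are now handled in one place,
                # regardless of which Sampler implementation is configured.
                HandleLogits()(scheduled_batch.context_requests,
                               scheduled_batch.generation_requests,
                               batch_outputs["logits"],
                               self.sampler.beam_width(scheduled_batch.all_requests()),
                               num_context_logits_prefix_sum,
                               self.sampler.is_generation_model())

                # The prefix sum is forwarded so samplers can locate the
                # generation-token logits inside the flattened tensor.
                return self.sampler.sample_async(scheduled_batch, batch_outputs,
                                                 num_context_logits_prefix_sum)
        except Exception as e:
            traceback.print_exc()
            error_msg = str(e)
            # ... (error handling continues as before)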
@@ -5,7 +5,6 @@ from typing import List, Literal, Optional

import torch

from tensorrt_llm._torch.pyexecutor.handle_logits import HandleLogits
from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
    MakeDecodingBatchInputOutput
from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding
@@ -30,7 +29,6 @@ from .scheduler import ScheduledRequests
@dataclass(kw_only=True)
class SampleStateTensors:
    new_tokens: torch.Tensor
    logits: torch.Tensor | None = None
    log_probs: torch.Tensor | None = None

    def values(self):
@@ -58,14 +56,24 @@ class Sampler(ABC):
        return None

    @abstractmethod
    def sample_async(self, scheduled_requests: ScheduledRequests,
                     model_outputs) -> SampleState:
    def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs,
                     num_context_logits_prefix_sum: list[int]) -> SampleState:
        raise NotImplementedError

    @abstractmethod
    def update_requests(self, state: SampleState) -> None:
        raise NotImplementedError

    @staticmethod
    def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
        for req in scheduled_requests:
            return req.sampling_config.beam_width
        return 0

    @abstractmethod
    def is_generation_model(self) -> bool:
        raise NotImplementedError


class EarlyStopSampler(Sampler):
    """
@@ -73,10 +81,9 @@ class EarlyStopSampler(Sampler):
    such as encoder-only model (e.g., BERT) or reward models that only need context phase.
    """

    def sample_async(self, scheduled_requests: ScheduledRequests,
                     model_outputs) -> SampleState:
        host = SampleStateTensors(logits=model_outputs['logits'],
                                  new_tokens=torch.empty(0))
    def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs,
                     num_context_logits_prefix_sum: list[int]) -> SampleState:
        host = SampleStateTensors(new_tokens=torch.empty(0))
        return SampleState(scheduled_requests=scheduled_requests, host=host)

    def update_requests(self, state: SampleState) -> None:
@@ -87,14 +94,9 @@ class EarlyStopSampler(Sampler):
            request.state = LlmRequestState.GENERATION_COMPLETE
            # NOTE: This is a hack: set finish reason manually and set the beam 0
            request.set_finished_reason(FinishReason.LENGTH, 0)
            if request.py_return_context_logits:
                logits = state.host.logits[idx]
                if logits.ndim == 1:
                    # For BERT: Add axis to be compatible with LogitsStorage
                    # (LogitsStorage will interpret this dim as the prompt_len which
                    # is not relevant for outputting logits of encoder only model).
                    logits = logits.unsqueeze(0)
                request.py_result.append_context_logits(logits)

    def is_generation_model(self) -> bool:
        return False


@dataclass(kw_only=True)
@@ -117,8 +119,10 @@ class EarlyStopWithMMResult(Sampler):
    Use for skipping decoding step for non generation model, and return the batch_output (such as mm_embeddings)
    """

    def sample_async(self, scheduled_requests: ScheduledRequests,
                     model_outputs) -> SampleStateWithMMResult:
    def sample_async(
            self, scheduled_requests: ScheduledRequests, model_outputs,
            num_context_logits_prefix_sum: list[int]
    ) -> SampleStateWithMMResult:
        # from model_outputs to MultimodalResult
        data = MultimodalResult(mm_embeddings=model_outputs['mm_embeddings'])
        return SampleStateWithMMResult(scheduled_requests=scheduled_requests,
@@ -141,6 +145,9 @@ class EarlyStopWithMMResult(Sampler):

            request.py_result.append_mm_embeddings(mm_embedding)

    def is_generation_model(self) -> bool:
        return False


def top_k_sampling_batch(logits,
                         top_k=50,
@@ -352,6 +359,9 @@ class TorchSampler(Sampler):
    BEAM = 0
    MAX_BEAM_WIDTH = BEAM + 1

    def is_generation_model(self) -> bool:
        return True

    @dataclass(frozen=True, kw_only=True)
    class Store:
        new_tokens: torch.Tensor
@@ -445,13 +455,9 @@ class TorchSampler(Sampler):

        return False

    def handle_logits(self, request: LlmRequest, state: SampleState, *,
                      beam: int, count: int):
    def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
                        beam: int, count: int):
        current_slice = slice(0, count), request.py_seq_slot, beam
        if request.py_return_generation_logits:
            assert state.host.logits is not None
            current_logits = state.host.logits[current_slice]
            request.py_result.append_generation_logits(current_logits)
        if request.py_return_log_probs:
            assert state.host.log_probs is not None
            log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
@@ -546,7 +552,7 @@ class TorchSampler(Sampler):
                continue
            new_token = add_token(req, new_tokens, beam=self.BEAM)
            self._handle_stop_criteria(req, new_token)
            self.handle_logits(req, state, beam=self.BEAM, count=1)
            self.handle_logprobs(req, state, beam=self.BEAM, count=1)
            req.py_decoding_iter += 1

        for req in state.scheduled_requests.generation_requests:
@@ -558,37 +564,28 @@ class TorchSampler(Sampler):
                req.py_num_accepted_draft_tokens = num_accepted
                req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
                processed += num_accepted
            self.handle_logits(req, state, beam=self.BEAM, count=processed)
            self.handle_logprobs(req, state, beam=self.BEAM, count=processed)
            req.py_decoding_iter += 1

    def log_probs_host(self, requests: Iterable[LlmRequest]):
    def log_probs_host(self, scheduled_requests: ScheduledRequests):
        """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
        if any(req.py_return_log_probs for req in requests):
        if any(req.py_return_log_probs
               for req in scheduled_requests.all_requests()):
            return torch.empty(
                (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
                device="cpu",
                pin_memory=True)
        return None

    def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int):
        if any(req.py_return_generation_logits for req in requests):
            return torch.empty((self.max_tokens, self.max_num_sequences,
                                self.MAX_BEAM_WIDTH, vocab_size),
                               device="cpu",
                               pin_memory=True)
        return None

    def sample_async(self, scheduled_requests: ScheduledRequests,
                     model_outputs: dict[str, torch.Tensor]) -> SampleState:
        requests = scheduled_requests.all_requests()
                     model_outputs: dict[str, torch.Tensor],
                     num_context_logits_prefix_sum: list[int]) -> SampleState:
        new_tokens = self.store.new_tokens
        vocab_size = model_outputs["logits"].shape[-1]
        log_probs_host = self.log_probs_host(requests)
        gen_logits_host = self.gen_logits_host(requests, vocab_size)
        self._process_requests(requests,
        log_probs_host = self.log_probs_host(scheduled_requests)
        self._process_requests(scheduled_requests,
                               model_outputs,
                               new_tokens,
                               gen_logits_host=gen_logits_host,
                               num_context_logits_prefix_sum,
                               log_probs_host=log_probs_host)
        new_tokens_host = new_tokens.to(device="cpu", non_blocking=True)
        sampler_event = torch.cuda.Event()
@@ -596,8 +593,7 @@ class TorchSampler(Sampler):
        return SampleState(scheduled_requests=scheduled_requests,
                           device=SampleStateTensors(new_tokens=new_tokens),
                           host=SampleStateTensors(new_tokens=new_tokens_host,
                                                   log_probs=log_probs_host,
                                                   logits=gen_logits_host),
                                                   log_probs=log_probs_host),
                           sampler_event=sampler_event)

    @staticmethod
@@ -659,19 +655,37 @@ class TorchSampler(Sampler):
        return logits

    def _process_requests(self,
                          requests: list[LlmRequest],
                          scheduled_requests: ScheduledRequests,
                          model_outputs: dict[str, torch.Tensor],
                          new_tokens: torch.Tensor,
                          num_context_logits_prefix_sum: list[int],
                          *,
                          gen_logits_host: torch.Tensor | None = None,
                          log_probs_host: torch.Tensor | None = None):
        beam_width = self.MAX_BEAM_WIDTH
        beam = self.BEAM
        raw_logits = model_outputs["logits"]

        # raw_logits should contain only the logits from the gen requests.
        # If return context logits is requested, fetch only the logits from gen requests.
        if any(r.py_return_context_logits
               for r in scheduled_requests.context_requests):
            gen_logits_indices = []
            total_context_logits = num_context_logits_prefix_sum[-1]
            for i in range(len(scheduled_requests.context_requests)):
                gen_logits_indices.append(num_context_logits_prefix_sum[i + 1] -
                                          1)
            gen_logits_indices.extend(
                range(
                    total_context_logits, total_context_logits +
                    len(scheduled_requests.generation_requests)))
            raw_logits = model_outputs["logits"][gen_logits_indices]
        else:
            raw_logits = model_outputs["logits"]

        requests = scheduled_requests.all_requests()
        num_steps = [1 + get_draft_token_length(req) for req in requests]
        sum_steps = sum(num_steps)
        no_draft_tokens = len(requests) == sum_steps
        fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
        fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None

        seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests])
        seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
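Note: a worked example (with made-up request sizes) of the generation-logits index selection shown above:

# Two context requests with chunk sizes 4 and 3, both returning context
# logits, plus two generation requests. The flattened logits tensor then has
# 4 + 3 = 7 context rows followed by 2 generation rows.
num_context_logits_prefix_sum = [0, 4, 7]
num_context_requests = 2
num_generation_requests = 2

gen_logits_indices = []
total_context_logits = num_context_logits_prefix_sum[-1]  # 7
for i in range(num_context_requests):
    # last-token logits of each context request: rows 3 and 6
    gen_logits_indices.append(num_context_logits_prefix_sum[i + 1] - 1)
gen_logits_indices.extend(
    range(total_context_logits,
          total_context_logits + num_generation_requests))  # rows 7, 8

print(gen_logits_indices)  # [3, 6, 7, 8]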
@@ -727,8 +741,6 @@ class TorchSampler(Sampler):
            new_tokens[current_slice] = next_tokens
            if request.py_draft_logits is not None:
                request.py_target_probs = softmax.clone()
            if gen_logits_host is not None:
                gen_logits_host[current_slice].copy_(logits, non_blocking=True)
            if log_probs_host is not None:
                assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
                token_probs = torch.gather(
@@ -769,6 +781,9 @@ class TRTLLMSampler(Sampler):
    MAX_DECODING_TOKENS = 1  # It must be 1 when not in speculative decoding
    SampleState = SampleStateTRTLLM

    def is_generation_model(self) -> bool:
        return True

    def __init__(
        self,
        executor_config: ExecutorConfig,
@@ -864,7 +879,6 @@ class TRTLLMSampler(Sampler):
            speculative_decoding_fast_logits=False,
            is_leader_in_orch_mode=False,
            is_normalize_log_probs=False)
        self.algs.handle_logits = HandleLogits()
        self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput(
        )

@@ -898,13 +912,6 @@ class TRTLLMSampler(Sampler):
        slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32)
        self.algs.decoder.underlying_decoder().setup(config, batch_size, slots)

    @staticmethod
    @torch.inference_mode()
    def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
        for req in scheduled_requests:
            return req.sampling_config.beam_width
        return 0

    def get_cache_indirection(self) -> torch.Tensor | None:
        return self.store["decoder_state"].cache_indirection_output

@@ -920,8 +927,9 @@ class TRTLLMSampler(Sampler):

    @torch.inference_mode()
    @nvtx_range("sample_async")
    def sample_async(self, scheduled_requests: ScheduledRequests,
                     model_outputs) -> SampleStateTRTLLM:
    def sample_async(
            self, scheduled_requests: ScheduledRequests, model_outputs,
            num_context_logits_prefix_sum: list[int]) -> SampleStateTRTLLM:

        batch_size = scheduled_requests.batch_size
        beam_width = self.beam_width(scheduled_requests.all_requests())
@@ -934,29 +942,10 @@ class TRTLLMSampler(Sampler):

        self.setup_sampler_step(scheduled_requests)

        num_context_logits_prefix_sum = [0]
        prefix_sum = 0
        for request in scheduled_requests.context_requests:
            prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
            num_context_logits_prefix_sum.append(prefix_sum)

        if any(r.py_return_context_logits or r.py_return_generation_logits
               for r in scheduled_requests.all_requests()):
            self.algs.handle_logits(scheduled_requests.context_requests,
                                    scheduled_requests.generation_requests,
                                    model_outputs["logits"],
                                    num_context_logits_prefix_sum,
                                    self.max_num_sequences, beam_width)

        # For beam search, cache indirection needs to be updated
        if beam_width > 1:
            self._update_cache_indirection_buffer(scheduled_requests)

        # TODO: Enable this back once nanobind is merged and/or llm request is a pure python object
        # decoding_input = self.algs.make_decoding_batch_input_output(
        #     scheduled_requests, model_outputs["logits"], beam_width,
        #     num_context_logits_prefix_sum)

        self.store["decoding_input"][
            self.micro_batch_idx] = make_decoding_batch_input(
                scheduled_requests.context_requests,
@@ -9,6 +9,7 @@ from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.logger import logger

from ..pyexecutor.guided_decoder import GuidedDecoder
from ..pyexecutor.handle_logits import HandleLogits
from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
                                      get_draft_token_length)
from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
@@ -266,7 +267,21 @@ class ModelDrafter(Drafter):
        """Sample tokens from draft model outputs."""
        try:
            if self.sampler is not None:
                return self.sampler.sample_async(draft_batch, outputs)
                num_context_logits_prefix_sum = [0]
                prefix_sum = 0
                for request in draft_batch.context_requests:
                    prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
                    num_context_logits_prefix_sum.append(prefix_sum)

                HandleLogits()(
                    draft_batch.context_requests,
                    draft_batch.generation_requests, outputs["logits"],
                    self.sampler.beam_width(draft_batch.all_requests()),
                    num_context_logits_prefix_sum,
                    self.sampler.is_generation_model())

                return self.sampler.sample_async(draft_batch, outputs,
                                                 num_context_logits_prefix_sum)
            return None
        except Exception as e:
            logger.error(f"Error in sampling: {str(e)}")
@@ -268,8 +268,10 @@ class MTPSampler(TorchSampler):
            req.py_rewind_len = self.draft_len - (num_new_tokens - 1)
            self._request_common_handling(req, next_draft_tokens_list)

    def sample_async(self, scheduled_requests: ScheduledRequests,
                     outputs: dict[str, torch.Tensor]) -> SampleStateMTP:
    def sample_async(
            self, scheduled_requests: ScheduledRequests,
            outputs: dict[str, torch.Tensor],
            num_context_logits_prefix_sum: list[int]) -> SampleStateMTP:
        # new_tokens_device: accepted tokens, device tensor, shape: batch_size, nextn + 1
        # new_tokens_lens_device: accepted lengths, device tensor, shape: batch_size
        # next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn
@@ -15,6 +15,7 @@
import os

import pytest
import torch
from defs.conftest import get_sm_version

from tensorrt_llm import LLM
@@ -398,6 +399,40 @@ class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
        task = CnnDailymail(self.MODEL_NAME)
        task.evaluate(llm)

    @skip_pre_hopper
    @pytest.mark.skip_less_device(4)
    @pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
    @pytest.mark.parametrize("pp_size", [2, 4], ids=["pp2", "pp4"])
    def test_return_logits_pp(self, pp_size, disable_overlap_scheduler):
        prompts = ["A B C"]

        llm = LLM(model=self.MODEL_PATH,
                  pipeline_parallel_size=pp_size,
                  disable_overlap_scheduler=disable_overlap_scheduler)

        sampling_params = SamplingParams(max_tokens=8,
                                         return_context_logits=True,
                                         return_generation_logits=True,
                                         logprobs=True)

        with llm:
            for output in llm.generate(prompts,
                                       sampling_params=sampling_params):
                assert output.context_logits is not None
                # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
                expected_len = len(prompts[0].split()) + 1
                assert expected_len == output.context_logits.shape[0]

                gen_logits = output.outputs[0].generation_logits
                assert gen_logits is not None
                assert gen_logits.ndim == 2
                assert gen_logits.shape[0] == sampling_params.max_tokens
                assert torch.argmax(
                    gen_logits, dim=1).tolist() == output.outputs[0].token_ids

                assert len(
                    output.outputs[0].logprobs) == sampling_params.max_tokens


class TestLlama3_2_3B(LlmapiAccuracyTestHarness):
    MODEL_NAME = "meta-llama/Llama-3.2-3B"
@@ -27,9 +27,6 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
            or return_log_probs): # prune space
        pytest.skip("Nothing to test")

    if sampler_type == "TorchSampler" and gather_context_logits:
        pytest.skip("TorchSampler does not support gather_context_logits")

    build_config = BuildConfig()
    build_config.gather_context_logits = gather_context_logits

@@ -94,9 +91,6 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
            or return_log_probs): # prune space
        pytest.skip("Nothing to test")

    if sampler_type == "TorchSampler" and gather_context_logits:
        pytest.skip("TorchSampler does not support gather_context_logits")

    build_config = BuildConfig()
    build_config.gather_context_logits = gather_context_logits