From 099f081e03ef8c92cadeb148915edad368b6daca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20C=C3=A1mpora?= <961215+dcampora@users.noreply.github.com> Date: Fri, 22 Aug 2025 08:09:30 +0200 Subject: [PATCH] [TRTLLM-7155][feat] Unify sampler handle logits implementation. (#6867) Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../pyexecutor/executor_request_queue.py | 11 +- .../_torch/pyexecutor/handle_logits.py | 23 ++- tensorrt_llm/_torch/pyexecutor/py_executor.py | 98 ++++++------ tensorrt_llm/_torch/pyexecutor/sampler.py | 149 ++++++++---------- .../_torch/speculative/model_drafter.py | 17 +- tensorrt_llm/_torch/speculative/mtp.py | 6 +- .../defs/accuracy/test_llm_api_pytorch.py | 35 ++++ .../_torch/sampler/test_return_logits.py | 6 - 8 files changed, 201 insertions(+), 144 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 96c5957ef9..8cfccb020a 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -16,7 +16,6 @@ from tensorrt_llm.mapping import CpType from ..distributed import Distributed from .llm_request import (ExecutorRequest, LlmRequest, executor_request_to_llm_request) -from .sampler import Sampler, TorchSampler SHUTDOWN_REQUEST_ID = -1 @@ -707,21 +706,19 @@ class ExecutorRequestQueue: def set_exclude_last_generation_logits(self, disable_overlap_scheduler: bool, - sampler: Sampler) -> None: + pp_size: int) -> None: # When overlap scheduler is enabled then when starting to handle a new prompt, # sample_async is called twice before the first call to update_requests: # - 1st time as a context request that handles on the 1st generated token # - 2nd time as a generation request that handles on the 2nd generated token. # and only after these two calls the sampler's update_request method is called. # So in a sampler that works by the expected flow of handling the logits in - # sample_async (TorchSampler is an anomaly that instead does that on - # update_requests), every update_request doesn't handle the newest token, but one + # sample_async, every update_request doesn't handle the newest token, but one # before it. Since all these calls work on the same request object, then its # logits storage contains the logits of both the token update_requests should work # on, and also its next token. Thus, excluding the last generation logits from any - # getter is required, when not using TorchSampler. - self.should_exclude_last_generation_logits = not disable_overlap_scheduler and not isinstance( - sampler, TorchSampler) + # getter is required. 
+ self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1 def _should_exclude_last_generation_logits(self) -> bool: return self.should_exclude_last_generation_logits diff --git a/tensorrt_llm/_torch/pyexecutor/handle_logits.py b/tensorrt_llm/_torch/pyexecutor/handle_logits.py index 81986df593..b3d7ced6a5 100644 --- a/tensorrt_llm/_torch/pyexecutor/handle_logits.py +++ b/tensorrt_llm/_torch/pyexecutor/handle_logits.py @@ -1,3 +1,4 @@ +from itertools import chain from typing import List import torch @@ -16,9 +17,9 @@ class HandleLogits: context_requests: List[LlmRequest], generation_requests: List[LlmRequest], logits: torch.Tensor, - num_context_logits_prefix_sum: List[int], - max_num_sequences: int, beam_width: int, + num_context_logits_prefix_sum: list[int], + is_generation_model: bool, ): """Handles context and generation logits for a batch of requests. @@ -26,10 +27,24 @@ class HandleLogits: context_requests: List of context requests to process generation_requests: List of generation requests to process logits: Input logits tensor - num_context_logits_prefix_sum: Prefix sum of context logits for each request - max_num_sequences: Maximum number of sequences to process beam_width: Beam width for the generation requests + num_context_logits_prefix_sum: Prefix sum of the logits + is_generation_model: Bool containing whether the model is generation or not """ + if not any(r.py_return_context_logits or r.py_return_generation_logits + for r in chain(context_requests, generation_requests)): + return + + if not is_generation_model: + for llm_req, logits_temp in zip(context_requests, logits): + if logits_temp.ndim == 1: + # For BERT: Add axis to be compatible with LogitsStorage + # (LogitsStorage will interpret this dim as the prompt_len which + # is not relevant for outputting logits of encoder only model). 
+ logits_temp = logits_temp.unsqueeze(0) + llm_req.py_result.append_context_logits(logits_temp) + return + # Copy logits into decoderBuffers.logits for batch_index, llm_req in enumerate(context_requests): logits_begin = num_context_logits_prefix_sum[batch_index] diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index a40b9b9045..453434d9d6 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -39,6 +39,7 @@ from ..models.modeling_utils import DecoderModelForCausalLM from ..speculative.drafter import Drafter from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem from .guided_decoder import GuidedDecoder +from .handle_logits import HandleLogits from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, LlmResponse) @@ -244,7 +245,7 @@ class PyExecutor: is_disaggregated=kv_cache_transceiver is not None, ) self.executor_request_queue.set_exclude_last_generation_logits( - self.disable_overlap_scheduler, self.sampler) + self.disable_overlap_scheduler, self.dist.pp_size) self.stats_lock = threading.Lock() self.stats = [] @@ -681,24 +682,6 @@ class PyExecutor: self.response_cv.notify_all() self.shutdown_event.set() - def _need_return_logits(self, scheduled_requests: ScheduledRequests): - for req in scheduled_requests.context_requests: - if req.py_return_context_logits: - return True - for req in scheduled_requests.generation_requests: - if req.py_return_generation_logits: - return True - return False - - def _need_return_log_probs(self, scheduled_requests: ScheduledRequests): - for req in scheduled_requests.context_requests: - if req.py_return_log_probs: - return True - for req in scheduled_requests.generation_requests: - if req.py_return_log_probs: - return True - return False - def _executor_loop_pp(self): logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}") torch.cuda.set_device(self.device_id) @@ -790,10 +773,6 @@ class PyExecutor: else: with torch.cuda.nvtx.range("_forward_step_last_pp"): batch_outputs = self._forward_step(scheduled_batch) - logits_host = None - if self._need_return_logits(scheduled_batch): - logits_host = batch_outputs["logits"].to( - "cpu", non_blocking=True) if self.kv_cache_transceiver and self.guided_decoder: self.guided_decoder.init_disagg_gen_requests( scheduled_batch) @@ -802,7 +781,6 @@ class PyExecutor: sample_state = self._sample_async( scheduled_batch, batch_outputs) - sample_state.host.logits = logits_host self._update_request_states(scheduled_batch) if self.enable_iter_perf_stats: @@ -832,18 +810,10 @@ class PyExecutor: torch.cuda.nvtx.range_push( "_handle_new_tokens_inter_pp") # Receive tokens from previous pp rank (w.r.t model forward direction) - ( - logits, - sample_state.host, - ) = self.dist.recv_object( + sample_state.host = self.dist.recv_object( src=self.dist.prev_pp_rank, tag=prev_microbatch_id, ) - if logits is not None: - logits_host = torch.from_numpy(logits) - sample_state.host.logits = logits_host - sample_state.device.logits = logits_host.to( - self.device_id) else: torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp") sample_state.sampler_event.synchronize() @@ -853,18 +823,9 @@ class PyExecutor: if not self.dist.is_second_last_pp_rank: if self.send_handles[prev_microbatch_id] is not None: self.send_handles[prev_microbatch_id].wait() - needs_logits = ( - self._need_return_logits(scheduled_batch) - or 
(self._need_return_log_probs(scheduled_batch) - and sample_state.host.log_probs is not None)) - serialized_logits = sample_state.host.logits.numpy( - ) if needs_logits else None self.send_handles[ prev_microbatch_id] = self.dist.isend_object( - ( - serialized_logits, - sample_state.host, - ), + sample_state.host, dest=self.dist.next_pp_rank, tag=prev_microbatch_id) torch.cuda.nvtx.range_pop() @@ -884,6 +845,40 @@ class PyExecutor: previous_batch.scheduled_ctx_reqs) self._handle_canceled_requests() + + # If logits were requested last PP rank has to send to first PP rank (who sends responses) the + # logits of the requests that have finished. + # NOTE: If the rank processing the logits ever becomes the same as + # the rank sending the responses, this code can be removed. + finished_reqs = [ + r for r in previous_batch.sample_state. + scheduled_requests.all_requests() + if r.state == LlmRequestState.GENERATION_COMPLETE + and (r.py_return_context_logits + or r.py_return_generation_logits) + ] + if self.dist.is_first_pp_rank and len(finished_reqs): + finished_reqs_py_results = [ + r.py_result for r in finished_reqs + ] + finished_reqs_py_results = self.dist.recv_object( + src=self.dist.prev_pp_rank, + tag=prev_microbatch_id, + ) + for req, py_result in zip(finished_reqs, + finished_reqs_py_results): + req.py_result = py_result + + elif self.dist.is_last_pp_rank and len(finished_reqs): + if self.send_handles[ + prev_microbatch_id] is not None: + self.send_handles[prev_microbatch_id].wait() + self.send_handles[ + prev_microbatch_id] = self.dist.isend_object( + [r.py_result for r in finished_reqs], + dest=self.dist.next_pp_rank, + tag=prev_microbatch_id) + finished_requests = self._handle_responses() previous_scheduled_batch = previous_batch.sample_state.scheduled_requests self.resource_manager.update_resources( @@ -1538,7 +1533,22 @@ class PyExecutor: batch_outputs) -> SampleState | None: try: if batch_outputs is not None: - return self.sampler.sample_async(scheduled_batch, batch_outputs) + num_context_logits_prefix_sum = [0] + prefix_sum = 0 + for request in scheduled_batch.context_requests: + prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1 + num_context_logits_prefix_sum.append(prefix_sum) + + HandleLogits()(scheduled_batch.context_requests, + scheduled_batch.generation_requests, + batch_outputs["logits"], + self.sampler.beam_width( + scheduled_batch.all_requests()), + num_context_logits_prefix_sum, + self.sampler.is_generation_model()) + + return self.sampler.sample_async(scheduled_batch, batch_outputs, + num_context_logits_prefix_sum) except Exception as e: traceback.print_exc() error_msg = str(e) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 919b99be2d..e6d19a9df4 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -5,7 +5,6 @@ from typing import List, Literal, Optional import torch -from tensorrt_llm._torch.pyexecutor.handle_logits import HandleLogits from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \ MakeDecodingBatchInputOutput from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding @@ -30,7 +29,6 @@ from .scheduler import ScheduledRequests @dataclass(kw_only=True) class SampleStateTensors: new_tokens: torch.Tensor - logits: torch.Tensor | None = None log_probs: torch.Tensor | None = None def values(self): @@ -58,14 +56,24 @@ class Sampler(ABC): return None @abstractmethod - def sample_async(self, 
scheduled_requests: ScheduledRequests, - model_outputs) -> SampleState: + def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs, + num_context_logits_prefix_sum: list[int]) -> SampleState: raise NotImplementedError @abstractmethod def update_requests(self, state: SampleState) -> None: raise NotImplementedError + @staticmethod + def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: + for req in scheduled_requests: + return req.sampling_config.beam_width + return 0 + + @abstractmethod + def is_generation_model(self) -> bool: + raise NotImplementedError + class EarlyStopSampler(Sampler): """ @@ -73,10 +81,9 @@ class EarlyStopSampler(Sampler): such as encoder-only model (e.g., BERT) or reward models that only need context phase. """ - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleState: - host = SampleStateTensors(logits=model_outputs['logits'], - new_tokens=torch.empty(0)) + def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs, + num_context_logits_prefix_sum: list[int]) -> SampleState: + host = SampleStateTensors(new_tokens=torch.empty(0)) return SampleState(scheduled_requests=scheduled_requests, host=host) def update_requests(self, state: SampleState) -> None: @@ -87,14 +94,9 @@ class EarlyStopSampler(Sampler): request.state = LlmRequestState.GENERATION_COMPLETE # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) - if request.py_return_context_logits: - logits = state.host.logits[idx] - if logits.ndim == 1: - # For BERT: Add axis to be compatible with LogitsStorage - # (LogitsStorage will interpret this dim as the prompt_len which - # is not relevant for outputting logits of encoder only model). 
- logits = logits.unsqueeze(0) - request.py_result.append_context_logits(logits) + + def is_generation_model(self) -> bool: + return False @dataclass(kw_only=True) @@ -117,8 +119,10 @@ class EarlyStopWithMMResult(Sampler): Use for skipping decoding step for non generation model, and return the batch_output (such as mm_embeddings) """ - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleStateWithMMResult: + def sample_async( + self, scheduled_requests: ScheduledRequests, model_outputs, + num_context_logits_prefix_sum: list[int] + ) -> SampleStateWithMMResult: # from model_outputs to MultimodalResult data = MultimodalResult(mm_embeddings=model_outputs['mm_embeddings']) return SampleStateWithMMResult(scheduled_requests=scheduled_requests, @@ -141,6 +145,9 @@ class EarlyStopWithMMResult(Sampler): request.py_result.append_mm_embeddings(mm_embedding) + def is_generation_model(self) -> bool: + return False + def top_k_sampling_batch(logits, top_k=50, @@ -352,6 +359,9 @@ class TorchSampler(Sampler): BEAM = 0 MAX_BEAM_WIDTH = BEAM + 1 + def is_generation_model(self) -> bool: + return True + @dataclass(frozen=True, kw_only=True) class Store: new_tokens: torch.Tensor @@ -445,13 +455,9 @@ class TorchSampler(Sampler): return False - def handle_logits(self, request: LlmRequest, state: SampleState, *, - beam: int, count: int): + def handle_logprobs(self, request: LlmRequest, state: SampleState, *, + beam: int, count: int): current_slice = slice(0, count), request.py_seq_slot, beam - if request.py_return_generation_logits: - assert state.host.logits is not None - current_logits = state.host.logits[current_slice] - request.py_result.append_generation_logits(current_logits) if request.py_return_log_probs: assert state.host.log_probs is not None log_probs = state.host.log_probs[request.py_seq_slot][beam][:count] @@ -546,7 +552,7 @@ class TorchSampler(Sampler): continue new_token = add_token(req, new_tokens, beam=self.BEAM) self._handle_stop_criteria(req, new_token) - self.handle_logits(req, state, beam=self.BEAM, count=1) + self.handle_logprobs(req, state, beam=self.BEAM, count=1) req.py_decoding_iter += 1 for req in state.scheduled_requests.generation_requests: @@ -558,37 +564,28 @@ class TorchSampler(Sampler): req.py_num_accepted_draft_tokens = num_accepted req.py_rewind_len = req.py_draft_pages_allocated - num_accepted processed += num_accepted - self.handle_logits(req, state, beam=self.BEAM, count=processed) + self.handle_logprobs(req, state, beam=self.BEAM, count=processed) req.py_decoding_iter += 1 - def log_probs_host(self, requests: Iterable[LlmRequest]): + def log_probs_host(self, scheduled_requests: ScheduledRequests): """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103""" - if any(req.py_return_log_probs for req in requests): + if any(req.py_return_log_probs + for req in scheduled_requests.all_requests()): return torch.empty( (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens), device="cpu", pin_memory=True) return None - def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int): - if any(req.py_return_generation_logits for req in requests): - return torch.empty((self.max_tokens, self.max_num_sequences, - self.MAX_BEAM_WIDTH, vocab_size), - device="cpu", - pin_memory=True) - return None - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs: dict[str, torch.Tensor]) -> SampleState: - 
requests = scheduled_requests.all_requests() + model_outputs: dict[str, torch.Tensor], + num_context_logits_prefix_sum: list[int]) -> SampleState: new_tokens = self.store.new_tokens - vocab_size = model_outputs["logits"].shape[-1] - log_probs_host = self.log_probs_host(requests) - gen_logits_host = self.gen_logits_host(requests, vocab_size) - self._process_requests(requests, + log_probs_host = self.log_probs_host(scheduled_requests) + self._process_requests(scheduled_requests, model_outputs, new_tokens, - gen_logits_host=gen_logits_host, + num_context_logits_prefix_sum, log_probs_host=log_probs_host) new_tokens_host = new_tokens.to(device="cpu", non_blocking=True) sampler_event = torch.cuda.Event() @@ -596,8 +593,7 @@ class TorchSampler(Sampler): return SampleState(scheduled_requests=scheduled_requests, device=SampleStateTensors(new_tokens=new_tokens), host=SampleStateTensors(new_tokens=new_tokens_host, - log_probs=log_probs_host, - logits=gen_logits_host), + log_probs=log_probs_host), sampler_event=sampler_event) @staticmethod @@ -659,19 +655,37 @@ class TorchSampler(Sampler): return logits def _process_requests(self, - requests: list[LlmRequest], + scheduled_requests: ScheduledRequests, model_outputs: dict[str, torch.Tensor], new_tokens: torch.Tensor, + num_context_logits_prefix_sum: list[int], *, - gen_logits_host: torch.Tensor | None = None, log_probs_host: torch.Tensor | None = None): beam_width = self.MAX_BEAM_WIDTH beam = self.BEAM - raw_logits = model_outputs["logits"] + + # raw_logits should contain only the logits from the gen requests. + # If return context logits is requested, fetch only the logits from gen requests. + if any(r.py_return_context_logits + for r in scheduled_requests.context_requests): + gen_logits_indices = [] + total_context_logits = num_context_logits_prefix_sum[-1] + for i in range(len(scheduled_requests.context_requests)): + gen_logits_indices.append(num_context_logits_prefix_sum[i + 1] - + 1) + gen_logits_indices.extend( + range( + total_context_logits, total_context_logits + + len(scheduled_requests.generation_requests))) + raw_logits = model_outputs["logits"][gen_logits_indices] + else: + raw_logits = model_outputs["logits"] + + requests = scheduled_requests.all_requests() num_steps = [1 + get_draft_token_length(req) for req in requests] sum_steps = sum(num_steps) no_draft_tokens = len(requests) == sum_steps - fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None + fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests]) seq_slots = seq_slots_host.to(device="cuda", non_blocking=True) @@ -727,8 +741,6 @@ class TorchSampler(Sampler): new_tokens[current_slice] = next_tokens if request.py_draft_logits is not None: request.py_target_probs = softmax.clone() - if gen_logits_host is not None: - gen_logits_host[current_slice].copy_(logits, non_blocking=True) if log_probs_host is not None: assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze" token_probs = torch.gather( @@ -769,6 +781,9 @@ class TRTLLMSampler(Sampler): MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding SampleState = SampleStateTRTLLM + def is_generation_model(self) -> bool: + return True + def __init__( self, executor_config: ExecutorConfig, @@ -864,7 +879,6 @@ class TRTLLMSampler(Sampler): speculative_decoding_fast_logits=False, is_leader_in_orch_mode=False, 
is_normalize_log_probs=False) - self.algs.handle_logits = HandleLogits() self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput( ) @@ -898,13 +912,6 @@ class TRTLLMSampler(Sampler): slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32) self.algs.decoder.underlying_decoder().setup(config, batch_size, slots) - @staticmethod - @torch.inference_mode() - def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: - for req in scheduled_requests: - return req.sampling_config.beam_width - return 0 - def get_cache_indirection(self) -> torch.Tensor | None: return self.store["decoder_state"].cache_indirection_output @@ -920,8 +927,9 @@ class TRTLLMSampler(Sampler): @torch.inference_mode() @nvtx_range("sample_async") - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleStateTRTLLM: + def sample_async( + self, scheduled_requests: ScheduledRequests, model_outputs, + num_context_logits_prefix_sum: list[int]) -> SampleStateTRTLLM: batch_size = scheduled_requests.batch_size beam_width = self.beam_width(scheduled_requests.all_requests()) @@ -934,29 +942,10 @@ class TRTLLMSampler(Sampler): self.setup_sampler_step(scheduled_requests) - num_context_logits_prefix_sum = [0] - prefix_sum = 0 - for request in scheduled_requests.context_requests: - prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1 - num_context_logits_prefix_sum.append(prefix_sum) - - if any(r.py_return_context_logits or r.py_return_generation_logits - for r in scheduled_requests.all_requests()): - self.algs.handle_logits(scheduled_requests.context_requests, - scheduled_requests.generation_requests, - model_outputs["logits"], - num_context_logits_prefix_sum, - self.max_num_sequences, beam_width) - # For beam search, cache indirection needs to be updated if beam_width > 1: self._update_cache_indirection_buffer(scheduled_requests) - # TODO: Enable this back once nanobind is merged and/or llm request is a pure python object - # decoding_input = self.algs.make_decoding_batch_input_output( - # scheduled_requests, model_outputs["logits"], beam_width, - # num_context_logits_prefix_sum) - self.store["decoding_input"][ self.micro_batch_idx] = make_decoding_batch_input( scheduled_requests.context_requests, diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 7f11142c3f..5d54f2f3be 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -9,6 +9,7 @@ from tensorrt_llm._utils import nvtx_range from tensorrt_llm.logger import logger from ..pyexecutor.guided_decoder import GuidedDecoder +from ..pyexecutor.handle_logits import HandleLogits from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState, get_draft_token_length) from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager @@ -266,7 +267,21 @@ class ModelDrafter(Drafter): """Sample tokens from draft model outputs.""" try: if self.sampler is not None: - return self.sampler.sample_async(draft_batch, outputs) + num_context_logits_prefix_sum = [0] + prefix_sum = 0 + for request in draft_batch.context_requests: + prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1 + num_context_logits_prefix_sum.append(prefix_sum) + + HandleLogits()( + draft_batch.context_requests, + draft_batch.generation_requests, outputs["logits"], + self.sampler.beam_width(draft_batch.all_requests()), + num_context_logits_prefix_sum, + 
self.sampler.is_generation_model()) + + return self.sampler.sample_async(draft_batch, outputs, + num_context_logits_prefix_sum) return None except Exception as e: logger.error(f"Error in sampling: {str(e)}") diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 2658ce539b..b31512df91 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -268,8 +268,10 @@ class MTPSampler(TorchSampler): req.py_rewind_len = self.draft_len - (num_new_tokens - 1) self._request_common_handling(req, next_draft_tokens_list) - def sample_async(self, scheduled_requests: ScheduledRequests, - outputs: dict[str, torch.Tensor]) -> SampleStateMTP: + def sample_async( + self, scheduled_requests: ScheduledRequests, + outputs: dict[str, torch.Tensor], + num_context_logits_prefix_sum: list[int]) -> SampleStateMTP: # new_tokens_device: accepted tokens, device tensor, shape: batch_size, nextn + 1 # new_tokens_lens_device: accepted lengths, device tensor, shape: batch_size # next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 618feaf928..0390c97e64 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -15,6 +15,7 @@ import os import pytest +import torch from defs.conftest import get_sm_version from tensorrt_llm import LLM @@ -398,6 +399,40 @@ class TestLlama3_2_1B(LlmapiAccuracyTestHarness): task = CnnDailymail(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_hopper + @pytest.mark.skip_less_device(4) + @pytest.mark.parametrize("disable_overlap_scheduler", [True, False]) + @pytest.mark.parametrize("pp_size", [2, 4], ids=["pp2", "pp4"]) + def test_return_logits_pp(self, pp_size, disable_overlap_scheduler): + prompts = ["A B C"] + + llm = LLM(model=self.MODEL_PATH, + pipeline_parallel_size=pp_size, + disable_overlap_scheduler=disable_overlap_scheduler) + + sampling_params = SamplingParams(max_tokens=8, + return_context_logits=True, + return_generation_logits=True, + logprobs=True) + + with llm: + for output in llm.generate(prompts, + sampling_params=sampling_params): + assert output.context_logits is not None + # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] + expected_len = len(prompts[0].split()) + 1 + assert expected_len == output.context_logits.shape[0] + + gen_logits = output.outputs[0].generation_logits + assert gen_logits is not None + assert gen_logits.ndim == 2 + assert gen_logits.shape[0] == sampling_params.max_tokens + assert torch.argmax( + gen_logits, dim=1).tolist() == output.outputs[0].token_ids + + assert len( + output.outputs[0].logprobs) == sampling_params.max_tokens + class TestLlama3_2_3B(LlmapiAccuracyTestHarness): MODEL_NAME = "meta-llama/Llama-3.2-3B" diff --git a/tests/unittest/_torch/sampler/test_return_logits.py b/tests/unittest/_torch/sampler/test_return_logits.py index 0d6a5e28ca..a3af16c8bc 100644 --- a/tests/unittest/_torch/sampler/test_return_logits.py +++ b/tests/unittest/_torch/sampler/test_return_logits.py @@ -27,9 +27,6 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool, or return_log_probs): # prune space pytest.skip("Nothing to test") - if sampler_type == "TorchSampler" and gather_context_logits: - pytest.skip("TorchSampler does not support gather_context_logits") - build_config = BuildConfig() 
build_config.gather_context_logits = gather_context_logits @@ -94,9 +91,6 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool, or return_log_probs): # prune space pytest.skip("Nothing to test") - if sampler_type == "TorchSampler" and gather_context_logits: - pytest.skip("TorchSampler does not support gather_context_logits") - build_config = BuildConfig() build_config.gather_context_logits = gather_context_logits
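
For illustration only (this sketch is not part of the diff above): the calling convention that the patch unifies across PyExecutor._sample_async and ModelDrafter._sample is "build the context-logits prefix sum, run HandleLogits once, then pass the prefix sum into sample_async". A minimal sketch of that flow, assuming the import paths and call signatures shown in the diff; the helper name sample_with_logits_handling is invented here:

from tensorrt_llm._torch.pyexecutor.handle_logits import HandleLogits
from tensorrt_llm._torch.pyexecutor.sampler import Sampler
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests


def sample_with_logits_handling(sampler: Sampler,
                                scheduled_batch: ScheduledRequests,
                                batch_outputs: dict):
    # 1) Prefix sum over context requests: a request contributes its full
    #    context chunk if it asked for context logits, otherwise one row.
    num_context_logits_prefix_sum = [0]
    prefix_sum = 0
    for request in scheduled_batch.context_requests:
        prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
        num_context_logits_prefix_sum.append(prefix_sum)

    # 2) One shared HandleLogits pass stores context/generation logits on the
    #    requests; it returns early when no request asked for logits.
    HandleLogits()(scheduled_batch.context_requests,
                   scheduled_batch.generation_requests,
                   batch_outputs["logits"],
                   sampler.beam_width(scheduled_batch.all_requests()),
                   num_context_logits_prefix_sum,
                   sampler.is_generation_model())

    # 3) Sampling proper, which now also receives the prefix sum.
    return sampler.sample_async(scheduled_batch, batch_outputs,
                                num_context_logits_prefix_sum)

TRTLLMSampler previously ran these steps inside sample_async; with this change both the executor and the drafter drive them from the caller side, so every Sampler implementation sees the same inputs.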
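
The gen_logits_indices selection added to TorchSampler._process_requests picks out one sampling row per request when context logits are interleaved into the logits tensor. A small worked example with hypothetical values, following the logic in the diff above:

# Hypothetical batch: context request A returns context logits (chunk size 3),
# context request B does not, and there is one generation request G.
# Rows of model_outputs["logits"]: [A0, A1, A2, B_last, G_last]
num_context_logits_prefix_sum = [0, 3, 4]  # A adds 3 rows, B adds 1
gen_logits_indices = [2, 3]                # prefix_sum[i + 1] - 1 for each context request
total_context_logits = 4                   # num_context_logits_prefix_sum[-1]
gen_logits_indices += [4]                  # one row per generation request
# raw_logits = model_outputs["logits"][[2, 3, 4]]: the last-token logits of A,
# the single logits row of B, and the row of G, i.e. one sampling row per
# request, as before the change.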