[TRTLLM-7155][feat] Unify sampler handle logits implementation. (#6867)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Daniel Cámpora 2025-08-22 08:09:30 +02:00 committed by GitHub
parent 983dd7e57c
commit 099f081e03
8 changed files with 201 additions and 144 deletions

View File

@ -16,7 +16,6 @@ from tensorrt_llm.mapping import CpType
from ..distributed import Distributed
from .llm_request import (ExecutorRequest, LlmRequest,
executor_request_to_llm_request)
from .sampler import Sampler, TorchSampler
SHUTDOWN_REQUEST_ID = -1
@ -707,21 +706,19 @@ class ExecutorRequestQueue:
def set_exclude_last_generation_logits(self,
disable_overlap_scheduler: bool,
sampler: Sampler) -> None:
pp_size: int) -> None:
# When the overlap scheduler is enabled, then when starting to handle a new prompt,
# sample_async is called twice before the first call to update_requests:
# - the 1st time as a context request that handles the 1st generated token,
# - the 2nd time as a generation request that handles the 2nd generated token,
# and only after these two calls is the sampler's update_requests method called.
# So in a sampler that works by the expected flow of handling the logits in
# sample_async (TorchSampler is an anomaly that instead does that on
# update_requests), every update_request doesn't handle the newest token, but one
# sample_async, every update_requests call doesn't handle the newest token, but the one
# before it. Since all these calls work on the same request object, its
# logits storage contains the logits of both the token that update_requests should
# work on and its next token. Thus, excluding the last generation logits from any
# getter is required, when not using TorchSampler.
self.should_exclude_last_generation_logits = not disable_overlap_scheduler and not isinstance(
sampler, TorchSampler)
# getter is required.
self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1
def _should_exclude_last_generation_logits(self) -> bool:
return self.should_exclude_last_generation_logits
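To make the exclusion above concrete: with the overlap scheduler, by the time update_requests handles token N, sample_async has already appended the logits for token N+1 to the same request's storage, so any getter has to hide the last entry. The sketch below only illustrates that behavior; ToyLogitsStorage is a hypothetical stand-in for the real LogitsStorage, and the flag mirrors the expression above.

import torch

class ToyLogitsStorage:
    """Hypothetical stand-in for LogitsStorage, for illustration only."""

    def __init__(self, exclude_last_generation_logits: bool):
        self.exclude_last = exclude_last_generation_logits
        self.generation_logits: list[torch.Tensor] = []

    def append(self, logits: torch.Tensor) -> None:
        self.generation_logits.append(logits)

    def get_generation_logits(self) -> list[torch.Tensor]:
        # With the overlap scheduler, sample_async for token N+1 has already
        # appended its logits by the time update_requests handles token N,
        # so the newest entry must not be exposed yet.
        if self.exclude_last and self.generation_logits:
            return self.generation_logits[:-1]
        return self.generation_logits

# Mirrors the flag computed above: exclude only when the overlap scheduler is
# active and there is a single pipeline stage.
disable_overlap_scheduler, pp_size = False, 1
storage = ToyLogitsStorage(not disable_overlap_scheduler and pp_size == 1)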

View File

@ -1,3 +1,4 @@
from itertools import chain
from typing import List
import torch
@ -16,9 +17,9 @@ class HandleLogits:
context_requests: List[LlmRequest],
generation_requests: List[LlmRequest],
logits: torch.Tensor,
num_context_logits_prefix_sum: List[int],
max_num_sequences: int,
beam_width: int,
num_context_logits_prefix_sum: list[int],
is_generation_model: bool,
):
"""Handles context and generation logits for a batch of requests.
@ -26,10 +27,24 @@ class HandleLogits:
context_requests: List of context requests to process
generation_requests: List of generation requests to process
logits: Input logits tensor
num_context_logits_prefix_sum: Prefix sum of context logits for each request
max_num_sequences: Maximum number of sequences to process
beam_width: Beam width for the generation requests
num_context_logits_prefix_sum: Prefix sum of the number of context logits per request
is_generation_model: Whether the model is a generation model
"""
if not any(r.py_return_context_logits or r.py_return_generation_logits
for r in chain(context_requests, generation_requests)):
return
if not is_generation_model:
for llm_req, logits_temp in zip(context_requests, logits):
if logits_temp.ndim == 1:
# For BERT: Add axis to be compatible with LogitsStorage
# (LogitsStorage will interpret this dim as the prompt_len, which
# is not relevant when outputting logits of an encoder-only model).
logits_temp = logits_temp.unsqueeze(0)
llm_req.py_result.append_context_logits(logits_temp)
return
# Copy logits into decoderBuffers.logits
for batch_index, llm_req in enumerate(context_requests):
logits_begin = num_context_logits_prefix_sum[batch_index]
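For orientation, the following standalone sketch (made-up sizes, not the HandleLogits implementation) shows how the prefix sum over per-request context-logit counts maps the flat batch logits tensor to per-request slices, using the same logits_begin arithmetic as above.

import torch

# Two context requests: the first returns context logits for a 4-token chunk,
# the second does not (so it contributes a single logit), followed by one
# generation request. The vocab size is arbitrary.
context_chunk_sizes = [4, 3]
return_context_logits = [True, False]
vocab_size = 8

num_context_logits_prefix_sum = [0]
for chunk, wants_logits in zip(context_chunk_sizes, return_context_logits):
    num_context_logits_prefix_sum.append(
        num_context_logits_prefix_sum[-1] + (chunk if wants_logits else 1))
# num_context_logits_prefix_sum == [0, 4, 5]

total_context_logits = num_context_logits_prefix_sum[-1]
logits = torch.randn(total_context_logits + 1, vocab_size)  # +1 generation row

for i, wants_logits in enumerate(return_context_logits):
    logits_begin = num_context_logits_prefix_sum[i]
    logits_end = num_context_logits_prefix_sum[i + 1]
    if wants_logits:
        context_logits = logits[logits_begin:logits_end]  # rows for request i
        print(f"request {i}: context logits rows {logits_begin}:{logits_end}")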

View File

@ -39,6 +39,7 @@ from ..models.modeling_utils import DecoderModelForCausalLM
from ..speculative.drafter import Drafter
from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
from .guided_decoder import GuidedDecoder
from .handle_logits import HandleLogits
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
LlmResponse)
@ -244,7 +245,7 @@ class PyExecutor:
is_disaggregated=kv_cache_transceiver is not None,
)
self.executor_request_queue.set_exclude_last_generation_logits(
self.disable_overlap_scheduler, self.sampler)
self.disable_overlap_scheduler, self.dist.pp_size)
self.stats_lock = threading.Lock()
self.stats = []
@ -681,24 +682,6 @@ class PyExecutor:
self.response_cv.notify_all()
self.shutdown_event.set()
def _need_return_logits(self, scheduled_requests: ScheduledRequests):
for req in scheduled_requests.context_requests:
if req.py_return_context_logits:
return True
for req in scheduled_requests.generation_requests:
if req.py_return_generation_logits:
return True
return False
def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
for req in scheduled_requests.context_requests:
if req.py_return_log_probs:
return True
for req in scheduled_requests.generation_requests:
if req.py_return_log_probs:
return True
return False
def _executor_loop_pp(self):
logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
torch.cuda.set_device(self.device_id)
@ -790,10 +773,6 @@ class PyExecutor:
else:
with torch.cuda.nvtx.range("_forward_step_last_pp"):
batch_outputs = self._forward_step(scheduled_batch)
logits_host = None
if self._need_return_logits(scheduled_batch):
logits_host = batch_outputs["logits"].to(
"cpu", non_blocking=True)
if self.kv_cache_transceiver and self.guided_decoder:
self.guided_decoder.init_disagg_gen_requests(
scheduled_batch)
@ -802,7 +781,6 @@ class PyExecutor:
sample_state = self._sample_async(
scheduled_batch, batch_outputs)
sample_state.host.logits = logits_host
self._update_request_states(scheduled_batch)
if self.enable_iter_perf_stats:
@ -832,18 +810,10 @@ class PyExecutor:
torch.cuda.nvtx.range_push(
"_handle_new_tokens_inter_pp")
# Receive tokens from previous pp rank (w.r.t model forward direction)
(
logits,
sample_state.host,
) = self.dist.recv_object(
sample_state.host = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
if logits is not None:
logits_host = torch.from_numpy(logits)
sample_state.host.logits = logits_host
sample_state.device.logits = logits_host.to(
self.device_id)
else:
torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp")
sample_state.sampler_event.synchronize()
@ -853,18 +823,9 @@ class PyExecutor:
if not self.dist.is_second_last_pp_rank:
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
needs_logits = (
self._need_return_logits(scheduled_batch)
or (self._need_return_log_probs(scheduled_batch)
and sample_state.host.log_probs is not None))
serialized_logits = sample_state.host.logits.numpy(
) if needs_logits else None
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
(
serialized_logits,
sample_state.host,
),
sample_state.host,
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
torch.cuda.nvtx.range_pop()
@ -884,6 +845,40 @@ class PyExecutor:
previous_batch.scheduled_ctx_reqs)
self._handle_canceled_requests()
# If logits were requested, the last PP rank has to send the logits of the requests
# that have finished to the first PP rank (which sends the responses).
# NOTE: If the rank processing the logits ever becomes the same as
# the rank sending the responses, this code can be removed.
finished_reqs = [
r for r in previous_batch.sample_state.
scheduled_requests.all_requests()
if r.state == LlmRequestState.GENERATION_COMPLETE
and (r.py_return_context_logits
or r.py_return_generation_logits)
]
if self.dist.is_first_pp_rank and len(finished_reqs):
finished_reqs_py_results = [
r.py_result for r in finished_reqs
]
finished_reqs_py_results = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
for req, py_result in zip(finished_reqs,
finished_reqs_py_results):
req.py_result = py_result
elif self.dist.is_last_pp_rank and len(finished_reqs):
if self.send_handles[
prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
[r.py_result for r in finished_reqs],
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
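Schematically, the block above pairs one isend_object on the last PP rank with one recv_object on the first PP rank per microbatch, keyed by the microbatch id as the tag; it works only because both ranks derive the same finished_reqs ordering. The helper below is a rough sketch of that handshake, not the executor code; forward_finished_results is a made-up name, and dist is assumed to expose the isend_object/recv_object calls used above.

def forward_finished_results(dist, finished_reqs, send_handle, tag):
    # finished_reqs must be computed identically on both ranks.
    if not finished_reqs:
        return send_handle
    if dist.is_last_pp_rank:
        if send_handle is not None:
            send_handle.wait()  # previous send must finish before reusing the slot
        send_handle = dist.isend_object([r.py_result for r in finished_reqs],
                                        dest=dist.next_pp_rank, tag=tag)
    elif dist.is_first_pp_rank:
        results = dist.recv_object(src=dist.prev_pp_rank, tag=tag)
        for req, py_result in zip(finished_reqs, results):
            req.py_result = py_result  # adopt the logits-bearing results
    return send_handle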
finished_requests = self._handle_responses()
previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
self.resource_manager.update_resources(
@ -1538,7 +1533,22 @@ class PyExecutor:
batch_outputs) -> SampleState | None:
try:
if batch_outputs is not None:
return self.sampler.sample_async(scheduled_batch, batch_outputs)
num_context_logits_prefix_sum = [0]
prefix_sum = 0
for request in scheduled_batch.context_requests:
prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
num_context_logits_prefix_sum.append(prefix_sum)
HandleLogits()(scheduled_batch.context_requests,
scheduled_batch.generation_requests,
batch_outputs["logits"],
self.sampler.beam_width(
scheduled_batch.all_requests()),
num_context_logits_prefix_sum,
self.sampler.is_generation_model())
return self.sampler.sample_async(scheduled_batch, batch_outputs,
num_context_logits_prefix_sum)
except Exception as e:
traceback.print_exc()
error_msg = str(e)

View File

@ -5,7 +5,6 @@ from typing import List, Literal, Optional
import torch
from tensorrt_llm._torch.pyexecutor.handle_logits import HandleLogits
from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
MakeDecodingBatchInputOutput
from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding
@ -30,7 +29,6 @@ from .scheduler import ScheduledRequests
@dataclass(kw_only=True)
class SampleStateTensors:
new_tokens: torch.Tensor
logits: torch.Tensor | None = None
log_probs: torch.Tensor | None = None
def values(self):
@ -58,14 +56,24 @@ class Sampler(ABC):
return None
@abstractmethod
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleState:
def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs,
num_context_logits_prefix_sum: list[int]) -> SampleState:
raise NotImplementedError
@abstractmethod
def update_requests(self, state: SampleState) -> None:
raise NotImplementedError
@staticmethod
def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
for req in scheduled_requests:
return req.sampling_config.beam_width
return 0
@abstractmethod
def is_generation_model(self) -> bool:
raise NotImplementedError
class EarlyStopSampler(Sampler):
"""
@ -73,10 +81,9 @@ class EarlyStopSampler(Sampler):
such as encoder-only model (e.g., BERT) or reward models that only need context phase.
"""
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleState:
host = SampleStateTensors(logits=model_outputs['logits'],
new_tokens=torch.empty(0))
def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs,
num_context_logits_prefix_sum: list[int]) -> SampleState:
host = SampleStateTensors(new_tokens=torch.empty(0))
return SampleState(scheduled_requests=scheduled_requests, host=host)
def update_requests(self, state: SampleState) -> None:
@ -87,14 +94,9 @@ class EarlyStopSampler(Sampler):
request.state = LlmRequestState.GENERATION_COMPLETE
# NOTE: This is a hack: set the finish reason manually on beam 0
request.set_finished_reason(FinishReason.LENGTH, 0)
if request.py_return_context_logits:
logits = state.host.logits[idx]
if logits.ndim == 1:
# For BERT: Add axis to be compatible with LogitsStorage
# (LogitsStorage will interpret this dim as the prompt_len which
# is not relevant for outputting logits of encoder only model).
logits = logits.unsqueeze(0)
request.py_result.append_context_logits(logits)
def is_generation_model(self) -> bool:
return False
@dataclass(kw_only=True)
@ -117,8 +119,10 @@ class EarlyStopWithMMResult(Sampler):
Use for skipping decoding step for non generation model, and return the batch_output (such as mm_embeddings)
"""
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleStateWithMMResult:
def sample_async(
self, scheduled_requests: ScheduledRequests, model_outputs,
num_context_logits_prefix_sum: list[int]
) -> SampleStateWithMMResult:
# from model_outputs to MultimodalResult
data = MultimodalResult(mm_embeddings=model_outputs['mm_embeddings'])
return SampleStateWithMMResult(scheduled_requests=scheduled_requests,
@ -141,6 +145,9 @@ class EarlyStopWithMMResult(Sampler):
request.py_result.append_mm_embeddings(mm_embedding)
def is_generation_model(self) -> bool:
return False
def top_k_sampling_batch(logits,
top_k=50,
@ -352,6 +359,9 @@ class TorchSampler(Sampler):
BEAM = 0
MAX_BEAM_WIDTH = BEAM + 1
def is_generation_model(self) -> bool:
return True
@dataclass(frozen=True, kw_only=True)
class Store:
new_tokens: torch.Tensor
@ -445,13 +455,9 @@ class TorchSampler(Sampler):
return False
def handle_logits(self, request: LlmRequest, state: SampleState, *,
beam: int, count: int):
def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
beam: int, count: int):
current_slice = slice(0, count), request.py_seq_slot, beam
if request.py_return_generation_logits:
assert state.host.logits is not None
current_logits = state.host.logits[current_slice]
request.py_result.append_generation_logits(current_logits)
if request.py_return_log_probs:
assert state.host.log_probs is not None
log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
@ -546,7 +552,7 @@ class TorchSampler(Sampler):
continue
new_token = add_token(req, new_tokens, beam=self.BEAM)
self._handle_stop_criteria(req, new_token)
self.handle_logits(req, state, beam=self.BEAM, count=1)
self.handle_logprobs(req, state, beam=self.BEAM, count=1)
req.py_decoding_iter += 1
for req in state.scheduled_requests.generation_requests:
@ -558,37 +564,28 @@ class TorchSampler(Sampler):
req.py_num_accepted_draft_tokens = num_accepted
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
processed += num_accepted
self.handle_logits(req, state, beam=self.BEAM, count=processed)
self.handle_logprobs(req, state, beam=self.BEAM, count=processed)
req.py_decoding_iter += 1
def log_probs_host(self, requests: Iterable[LlmRequest]):
def log_probs_host(self, scheduled_requests: ScheduledRequests):
"""Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
if any(req.py_return_log_probs for req in requests):
if any(req.py_return_log_probs
for req in scheduled_requests.all_requests()):
return torch.empty(
(self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
device="cpu",
pin_memory=True)
return None
def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int):
if any(req.py_return_generation_logits for req in requests):
return torch.empty((self.max_tokens, self.max_num_sequences,
self.MAX_BEAM_WIDTH, vocab_size),
device="cpu",
pin_memory=True)
return None
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs: dict[str, torch.Tensor]) -> SampleState:
requests = scheduled_requests.all_requests()
model_outputs: dict[str, torch.Tensor],
num_context_logits_prefix_sum: list[int]) -> SampleState:
new_tokens = self.store.new_tokens
vocab_size = model_outputs["logits"].shape[-1]
log_probs_host = self.log_probs_host(requests)
gen_logits_host = self.gen_logits_host(requests, vocab_size)
self._process_requests(requests,
log_probs_host = self.log_probs_host(scheduled_requests)
self._process_requests(scheduled_requests,
model_outputs,
new_tokens,
gen_logits_host=gen_logits_host,
num_context_logits_prefix_sum,
log_probs_host=log_probs_host)
new_tokens_host = new_tokens.to(device="cpu", non_blocking=True)
sampler_event = torch.cuda.Event()
@ -596,8 +593,7 @@ class TorchSampler(Sampler):
return SampleState(scheduled_requests=scheduled_requests,
device=SampleStateTensors(new_tokens=new_tokens),
host=SampleStateTensors(new_tokens=new_tokens_host,
log_probs=log_probs_host,
logits=gen_logits_host),
log_probs=log_probs_host),
sampler_event=sampler_event)
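The non_blocking copies into pinned host tensors (log_probs_host above, new_tokens_host here) are only safe to read after the recorded CUDA event completes, which is why callers synchronize sample_state.sampler_event before touching host tensors. A minimal, standalone illustration of that pattern (requires a CUDA device; not the sampler's tensor layout):

import torch

device_tensor = torch.randint(0, 100, (4, 8), device="cuda")
# Pinned host memory enables a truly asynchronous device-to-host copy.
host_tensor = torch.empty(device_tensor.shape, dtype=device_tensor.dtype,
                          device="cpu", pin_memory=True)
host_tensor.copy_(device_tensor, non_blocking=True)

sampler_event = torch.cuda.Event()
sampler_event.record()       # record after enqueuing the copy
# ... CPU work can overlap with the copy here ...
sampler_event.synchronize()  # only now is host_tensor guaranteed to be filled
assert torch.equal(host_tensor, device_tensor.cpu())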
@staticmethod
@ -659,19 +655,37 @@ class TorchSampler(Sampler):
return logits
def _process_requests(self,
requests: list[LlmRequest],
scheduled_requests: ScheduledRequests,
model_outputs: dict[str, torch.Tensor],
new_tokens: torch.Tensor,
num_context_logits_prefix_sum: list[int],
*,
gen_logits_host: torch.Tensor | None = None,
log_probs_host: torch.Tensor | None = None):
beam_width = self.MAX_BEAM_WIDTH
beam = self.BEAM
raw_logits = model_outputs["logits"]
# raw_logits should contain only the logits used for sampling: the last context-token
# logit of each context request plus the generation-request logits. If any request
# asked for context logits, the flat batch logits also include full context chunks,
# so gather only the rows needed for sampling.
if any(r.py_return_context_logits
for r in scheduled_requests.context_requests):
gen_logits_indices = []
total_context_logits = num_context_logits_prefix_sum[-1]
for i in range(len(scheduled_requests.context_requests)):
gen_logits_indices.append(num_context_logits_prefix_sum[i + 1] -
1)
gen_logits_indices.extend(
range(
total_context_logits, total_context_logits +
len(scheduled_requests.generation_requests)))
raw_logits = model_outputs["logits"][gen_logits_indices]
else:
raw_logits = model_outputs["logits"]
requests = scheduled_requests.all_requests()
num_steps = [1 + get_draft_token_length(req) for req in requests]
sum_steps = sum(num_steps)
no_draft_tokens = len(requests) == sum_steps
fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None
seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests])
seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
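To make the index arithmetic above concrete, assume two context requests, the first returning context logits for a 4-token chunk and the second not, followed by two generation requests. The standalone sketch below (made-up sizes) reproduces which rows of the flat logits tensor end up in raw_logits.

num_context_logits_prefix_sum = [0, 4, 5]  # as built in _sample_async above
num_generation_requests = 2

gen_logits_indices = []
# The last context-token logit of each context request is the one to sample from.
for i in range(len(num_context_logits_prefix_sum) - 1):
    gen_logits_indices.append(num_context_logits_prefix_sum[i + 1] - 1)
# Generation-request logits follow the packed context logits.
total_context_logits = num_context_logits_prefix_sum[-1]
gen_logits_indices.extend(
    range(total_context_logits, total_context_logits + num_generation_requests))

print(gen_logits_indices)  # [3, 4, 5, 6]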
@ -727,8 +741,6 @@ class TorchSampler(Sampler):
new_tokens[current_slice] = next_tokens
if request.py_draft_logits is not None:
request.py_target_probs = softmax.clone()
if gen_logits_host is not None:
gen_logits_host[current_slice].copy_(logits, non_blocking=True)
if log_probs_host is not None:
assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
token_probs = torch.gather(
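The gather above (truncated in this view) picks, for each sampled token, its probability out of the softmax over the vocabulary, from which the per-token log-probs are derived. A minimal standalone version of that pattern, not the sampler's exact tensor layout:

import torch

logits = torch.randn(5, 32)               # [num_tokens, vocab_size]
next_tokens = torch.randint(0, 32, (5,))  # sampled token ids

probs = torch.softmax(logits, dim=-1)
token_probs = torch.gather(probs, dim=1, index=next_tokens.unsqueeze(1)).squeeze(1)
log_probs = torch.log(token_probs)        # shape [num_tokens]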
@ -769,6 +781,9 @@ class TRTLLMSampler(Sampler):
MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding
SampleState = SampleStateTRTLLM
def is_generation_model(self) -> bool:
return True
def __init__(
self,
executor_config: ExecutorConfig,
@ -864,7 +879,6 @@ class TRTLLMSampler(Sampler):
speculative_decoding_fast_logits=False,
is_leader_in_orch_mode=False,
is_normalize_log_probs=False)
self.algs.handle_logits = HandleLogits()
self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput(
)
@ -898,13 +912,6 @@ class TRTLLMSampler(Sampler):
slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32)
self.algs.decoder.underlying_decoder().setup(config, batch_size, slots)
@staticmethod
@torch.inference_mode()
def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
for req in scheduled_requests:
return req.sampling_config.beam_width
return 0
def get_cache_indirection(self) -> torch.Tensor | None:
return self.store["decoder_state"].cache_indirection_output
@ -920,8 +927,9 @@ class TRTLLMSampler(Sampler):
@torch.inference_mode()
@nvtx_range("sample_async")
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleStateTRTLLM:
def sample_async(
self, scheduled_requests: ScheduledRequests, model_outputs,
num_context_logits_prefix_sum: list[int]) -> SampleStateTRTLLM:
batch_size = scheduled_requests.batch_size
beam_width = self.beam_width(scheduled_requests.all_requests())
@ -934,29 +942,10 @@ class TRTLLMSampler(Sampler):
self.setup_sampler_step(scheduled_requests)
num_context_logits_prefix_sum = [0]
prefix_sum = 0
for request in scheduled_requests.context_requests:
prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
num_context_logits_prefix_sum.append(prefix_sum)
if any(r.py_return_context_logits or r.py_return_generation_logits
for r in scheduled_requests.all_requests()):
self.algs.handle_logits(scheduled_requests.context_requests,
scheduled_requests.generation_requests,
model_outputs["logits"],
num_context_logits_prefix_sum,
self.max_num_sequences, beam_width)
# For beam search, cache indirection needs to be updated
if beam_width > 1:
self._update_cache_indirection_buffer(scheduled_requests)
# TODO: Re-enable this once nanobind is merged and/or llm request is a pure python object
# decoding_input = self.algs.make_decoding_batch_input_output(
# scheduled_requests, model_outputs["logits"], beam_width,
# num_context_logits_prefix_sum)
self.store["decoding_input"][
self.micro_batch_idx] = make_decoding_batch_input(
scheduled_requests.context_requests,

View File

@ -9,6 +9,7 @@ from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.logger import logger
from ..pyexecutor.guided_decoder import GuidedDecoder
from ..pyexecutor.handle_logits import HandleLogits
from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
get_draft_token_length)
from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
@ -266,7 +267,21 @@ class ModelDrafter(Drafter):
"""Sample tokens from draft model outputs."""
try:
if self.sampler is not None:
return self.sampler.sample_async(draft_batch, outputs)
num_context_logits_prefix_sum = [0]
prefix_sum = 0
for request in draft_batch.context_requests:
prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
num_context_logits_prefix_sum.append(prefix_sum)
HandleLogits()(
draft_batch.context_requests,
draft_batch.generation_requests, outputs["logits"],
self.sampler.beam_width(draft_batch.all_requests()),
num_context_logits_prefix_sum,
self.sampler.is_generation_model())
return self.sampler.sample_async(draft_batch, outputs,
num_context_logits_prefix_sum)
return None
except Exception as e:
logger.error(f"Error in sampling: {str(e)}")

View File

@ -268,8 +268,10 @@ class MTPSampler(TorchSampler):
req.py_rewind_len = self.draft_len - (num_new_tokens - 1)
self._request_common_handling(req, next_draft_tokens_list)
def sample_async(self, scheduled_requests: ScheduledRequests,
outputs: dict[str, torch.Tensor]) -> SampleStateMTP:
def sample_async(
self, scheduled_requests: ScheduledRequests,
outputs: dict[str, torch.Tensor],
num_context_logits_prefix_sum: list[int]) -> SampleStateMTP:
# new_tokens_device: accepted tokens, device tensor, shape: batch_size, nextn + 1
# new_tokens_lens_device: accepted lengths, device tensor, shape: batch_size
# next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn

View File

@ -15,6 +15,7 @@
import os
import pytest
import torch
from defs.conftest import get_sm_version
from tensorrt_llm import LLM
@ -398,6 +399,40 @@ class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
@skip_pre_hopper
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
@pytest.mark.parametrize("pp_size", [2, 4], ids=["pp2", "pp4"])
def test_return_logits_pp(self, pp_size, disable_overlap_scheduler):
prompts = ["A B C"]
llm = LLM(model=self.MODEL_PATH,
pipeline_parallel_size=pp_size,
disable_overlap_scheduler=disable_overlap_scheduler)
sampling_params = SamplingParams(max_tokens=8,
return_context_logits=True,
return_generation_logits=True,
logprobs=True)
with llm:
for output in llm.generate(prompts,
sampling_params=sampling_params):
assert output.context_logits is not None
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
expected_len = len(prompts[0].split()) + 1
assert expected_len == output.context_logits.shape[0]
gen_logits = output.outputs[0].generation_logits
assert gen_logits is not None
assert gen_logits.ndim == 2
assert gen_logits.shape[0] == sampling_params.max_tokens
assert torch.argmax(
gen_logits, dim=1).tolist() == output.outputs[0].token_ids
assert len(
output.outputs[0].logprobs) == sampling_params.max_tokens
class TestLlama3_2_3B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.2-3B"

View File

@ -27,9 +27,6 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
or return_log_probs): # prune space
pytest.skip("Nothing to test")
if sampler_type == "TorchSampler" and gather_context_logits:
pytest.skip("TorchSampler does not support gather_context_logits")
build_config = BuildConfig()
build_config.gather_context_logits = gather_context_logits
@ -94,9 +91,6 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
or return_log_probs): # prune space
pytest.skip("Nothing to test")
if sampler_type == "TorchSampler" and gather_context_logits:
pytest.skip("TorchSampler does not support gather_context_logits")
build_config = BuildConfig()
build_config.gather_context_logits = gather_context_logits