[TRTLLM-9735][feat] Add processed logprobs functionality to TorchSampler (#9675)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Co-authored-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Co-authored-by: Erin Ho <14718778+hchings@users.noreply.github.com>
This commit is contained in:
Stefan Niebler 2026-01-16 19:52:41 +01:00 committed by GitHub
parent cfebfbb505
commit 0cfd08745c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1349 additions and 440 deletions

View File

@ -1020,7 +1020,6 @@ common-files: &common_files |
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
tests/unittest/_torch/sampler/test_beam_search.py |
tests/unittest/_torch/sampler/test_best_of_n.py |
tests/unittest/_torch/sampler/test_return_logits.py |
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
tests/unittest/_torch/speculative/test_draft_target.py |

View File

@ -1061,7 +1061,6 @@ exclude = [
"tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
"tests/unittest/_torch/sampler/test_beam_search.py",
"tests/unittest/_torch/sampler/test_best_of_n.py",
"tests/unittest/_torch/sampler/test_return_logits.py",
"tests/unittest/_torch/sampler/test_torch_multi_arange.py",
"tests/unittest/_torch/sampler/test_trtllm_sampler.py",
"tests/unittest/_torch/speculative/test_draft_target.py",

View File

@ -8,6 +8,7 @@ import tensorrt_llm.bindings
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
from tensorrt_llm.bindings import executor as tllm_executor
from tensorrt_llm.executor.result import TokenLogprobs
from tensorrt_llm.sampling_params import LogprobMode
SamplingConfig = tensorrt_llm.bindings.SamplingConfig
'''
@ -485,6 +486,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
is_first_draft: bool = False,
use_chunked_generation_logits: bool = True,
logits_chunk_size: int = 8,
logprobs_mode: LogprobMode = LogprobMode.RAW,
**kwargs):
self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@ -566,6 +568,9 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
# currently, keep py_stop_words_list as python list, rather than tensor.
self.py_stop_words_list = stop_words_list
self.py_logprobs_mode = LogprobMode(
logprobs_mode)  # handle the case where a raw string is passed
self.py_result = PyResult(
prompt_len=self.py_prompt_len,
max_new_tokens=self.py_max_new_tokens,
@ -825,7 +830,10 @@ def executor_request_to_llm_request(
arrival_time=getattr(executor_request, "py_arrival_time", None),
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
None),
kv_cache_retention_config=executor_request.kv_cache_retention_config)
kv_cache_retention_config=executor_request.kv_cache_retention_config,
logprobs_mode=getattr(executor_request, "py_logprobs_mode",
LogprobMode.RAW),
)
if child_req_ids:
for child_id in child_req_ids:
llm_request.create_child_request(child_id)

File diff suppressed because it is too large

View File

@ -266,7 +266,8 @@ def greedy_search_sampling_batch(
next_tokens = torch.argmax(logits, dim=-1)
softmax: Optional[torch.Tensor] = None
if return_probs:
softmax = torch.softmax(logits, dim=-1)
softmax = torch.zeros_like(logits)
softmax.scatter_(1, next_tokens.unsqueeze(-1), 1.0)
return next_tokens, softmax
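For context on the hunk above (editor's note, not part of the diff): under greedy decoding the returned "softmax" is now a one-hot distribution rather than a true softmax, so processed logprobs report probability 1 for the argmax token and -inf elsewhere. A minimal sketch, assuming a [batch, vocab] float logits tensor:

import torch

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.0, 0.5]])
next_tokens = torch.argmax(logits, dim=-1)          # tensor([1, 0])
probs = torch.zeros_like(logits)
probs.scatter_(1, next_tokens.unsqueeze(-1), 1.0)   # one-hot rows
# log(probs) is 0.0 at the sampled token and -inf elsewhere, which is what
# LogprobMode.PROCESSED reports for greedy requests.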
@ -471,10 +472,10 @@ def sample(
strategy: Strategy,
logits: torch.Tensor,
*,
generator: Optional[torch.Generator] = None,
generator: torch.Generator | None = None,
group_metadata: StrategyMetadata | None = None,
return_probs: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor | None, float | None]:
match strategy:
case ("top_k", top_k, temperature):
tokens, softmax = top_k_sampling_batch(
@ -506,6 +507,7 @@ def sample(
)
case ("greedy", None):
tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs)
temperature = None
case ("beam_search", beam_width_in, beam_width_out, temperature):
assert group_metadata is not None and isinstance(group_metadata, BeamSearchMetadata), (
"BeamSearchMetadata is required for beam_search_sampling_batch"
@ -519,7 +521,7 @@ def sample(
generator=generator,
return_probs=return_probs,
)
return tokens, softmax
return tokens, softmax, temperature
GenericStrategyKeyType = TypeVar("GenericStrategyKeyType")
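A hedged caller-side sketch of the new return contract (the import path is assumed; the strategy tuple layout follows the match statement above): sample() now also reports the temperature that was applied, so logits can later be rescaled when building processed logprobs.

import torch
from tensorrt_llm._torch.pyexecutor.sampling_utils import sample  # assumed location

logits = torch.randn(4, 32000)
tokens, softmax, temperature = sample(
    ("top_k", 50, 0.8),   # (strategy name, top_k, temperature), as matched above
    logits,
    return_probs=True,
)
# For the ("greedy", None) strategy the returned temperature is None, i.e. no
# temperature rescaling is needed.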
@ -545,11 +547,11 @@ class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
strategies: list[Strategy],
logits: torch.Tensor,
*,
group_logit_indices: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
group_logit_indices: torch.Tensor | None = None,
generator: torch.Generator | None = None,
return_probs: bool,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]:
raise NotImplementedError
@ -579,11 +581,11 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
strategies: list[Strategy],
logits: torch.Tensor,
*,
group_logit_indices: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
group_logit_indices: torch.Tensor | None = None,
generator: torch.Generator | None = None,
return_probs: bool,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor | None, float | None]:
if group_key[0] == "beam_search":
beam_width_in = group_key[1]
else:

View File

@ -141,8 +141,9 @@ class _StrategyImpls:
*,
group_logit_indices: Optional[torch.Tensor],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
probs = self._prepare_probs_with_temperature(logits, group_logit_indices, None)
new_tokens, _ = greedy_search_sampling_batch(probs, return_probs=False)
if group_logit_indices is not None:
logits = torch.index_select(logits, 0, group_logit_indices) # ensures copy
new_tokens, probs = greedy_search_sampling_batch(logits, return_probs=True)
return new_tokens, probs
@classmethod
@ -240,6 +241,9 @@ class _StrategyImpls:
return True
class GreedyWithProbs(StrategyImplWithProbs):
def __init__(self):
self._temperature = None
@override
@classmethod
def from_strategies(
@ -425,6 +429,9 @@ class _StrategyImpls:
return False
class GreedySampleOnly(StrategyImplSampleOnly):
def __init__(self):
self._temperature = None
@override
@classmethod
def from_strategies(
@ -722,7 +729,7 @@ class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpl
generator: Optional[torch.Generator] = None,
return_probs: bool,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
if hasattr(group_key, "static_beam_width_in"):
beam_width_in = group_key.static_beam_width_in
else:
@ -735,9 +742,16 @@ class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpl
assert return_probs == group_key.computes_probs()
strategy_impl_cls = group_key
return strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device).sample(
sampling_object = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device)
next_tokens, softmax = sampling_object.sample(
logits,
group_logit_indices=group_logit_indices,
generator=generator,
group_metadata=group_metadata,
)
temperature = (
sampling_object._temperature.unsqueeze(-1)
if sampling_object._temperature is not None
else None
)
return next_tokens, softmax, temperature
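Shape note (editor's sketch, assuming _temperature is a per-request 1-D tensor): unsqueeze(-1) turns it into [batch, 1] so it broadcasts against [batch, vocab] logits when they are rescaled for processed logprobs.

import torch

per_request_temperature = torch.tensor([0.7, 1.0, 0.8])   # [batch]
logits = torch.randn(3, 32000)
scaled = logits / per_request_temperature.unsqueeze(-1)   # [batch, 1] broadcasts over vocab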

View File

@ -215,7 +215,7 @@ class MTPSampler(TorchSampler):
SampleState = SampleStateMTP
@dataclass(frozen=True, kw_only=True)
@dataclass(kw_only=True)
class Store(TorchSampler.Store):
new_tokens: torch.Tensor
next_new_tokens: torch.Tensor

View File

@ -562,6 +562,7 @@ class BaseWorker(GenerationExecutor):
cache_salt_id=request.cache_salt_id)
executor_request.py_num_logprobs = request.sampling_params.logprobs
executor_request.py_lora_path = py_lora_path
executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode
if self._is_pytorch_backend and request.multimodal_params is not None:
if request.multimodal_params.multimodal_data is not None:

View File

@ -221,7 +221,7 @@ class GenerationExecutor(ABC):
self, request: GenerationRequest) -> Optional[LogprobParams]:
"""Store logprobs-related fields from request for the later logprob calculation."""
logprob_params = None
if request.sampling_params.logprobs or request.sampling_params.prompt_logprobs:
if request.sampling_params.logprobs is not None or request.sampling_params.prompt_logprobs:
logprob_params = LogprobParams(
logprobs=request.sampling_params.logprobs,
prompt_logprobs=request.sampling_params.prompt_logprobs,
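The switch from a truthiness test to an explicit None check matters because logprobs == 0 is now a valid request (sampled token only) but is falsy in Python. A tiny illustration:

logprobs = 0                     # "return only the sampled token's logprob"
assert not logprobs              # falsy, so a bare truthiness check would drop it
assert logprobs is not None      # the check used above keeps it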

View File

@ -933,6 +933,21 @@ def compute_logprobs(
logits = logits[:len(tokens)]
logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1)
# logprobs == 0: return only the sampled token's logprob
if top_k == 0:
results: TokenLogprobs = []
if tokens is not None:
for t in range(logprobs.size(0)):
token_id = tokens[t]
token_logprob = logprobs[t, token_id].item()
rank = (logprobs[t] > token_logprob).sum().item() + 1
token_dict = {
token_id: Logprob(logprob=token_logprob, rank=rank)
}
results.append(token_dict)
return results
topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)
results: TokenLogprobs = []
@ -961,7 +976,7 @@ def compute_logprobs(
None) if k_prompt_logprobs and context_logits is not None else None
generation_logprobs = _topk_logprobs(
generation_logits, k_logprobs, output_token_ids
) if k_logprobs and generation_logits is not None else None
) if k_logprobs is not None and generation_logits is not None else None
return LogProbsResult(prompt=prompt_logprobs,
generation=generation_logprobs)
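A worked sketch (editor's addition; Logprob and TokenLogprobs as used in the hunk above) of what the logprobs == 0 branch produces for a single position:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0]])
logprobs = F.log_softmax(logits.float(), dim=-1)
token_id = 2                                             # the sampled token
token_logprob = logprobs[0, token_id].item()
rank = (logprobs[0] > token_logprob).sum().item() + 1    # 1-based rank; here 2
# result for this position: {2: Logprob(logprob=token_logprob, rank=2)} -- a
# single-entry dict containing only the sampled token.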

View File

@ -666,7 +666,7 @@ class BaseLLM:
if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
sampling_params.return_context_logits = True
sampling_params._context_logits_auto_enabled = True
if sampling_params.logprobs and not sampling_params.return_generation_logits:
if sampling_params.logprobs is not None and not sampling_params.return_generation_logits:
sampling_params.return_generation_logits = True
sampling_params._generation_logits_auto_enabled = True
@ -737,7 +737,7 @@ class BaseLLM:
f"Example: LLM(..., build_config=BuildConfig(gather_context_logits=True))."
)
if sampling_params.logprobs and not self.args.gather_generation_logits:
if sampling_params.logprobs is not None and not self.args.gather_generation_logits:
raise ValueError(
f"`sampling_params.logprobs={sampling_params.logprobs}` requires `gather_generation_logits=True` "
f"to be passed explicitly to the `LLM()` constructor.")

View File

@ -6,6 +6,7 @@ from typing import List, NamedTuple, Optional, Tuple, Union
import torch
from pydantic import BaseModel
from strenum import StrEnum
from tensorrt_llm.bindings import executor as tllme
from tensorrt_llm.logger import logger
@ -46,6 +47,18 @@ class LogprobParams(NamedTuple):
drop_generation_logits: bool = False
class LogprobMode(StrEnum):
RAW = "raw"
"""
Return the raw log probabilities, i.e., the log probabilities calculated directly from the model output logits.
"""
PROCESSED = "processed"
"""
Return the processed log probabilities, i.e., the log probabilities after applying sampling parameters,
such as temperature, top-k, top-p, etc.
"""
class LogitsProcessor(ABC):
"""Base class for logits processor.
@ -172,7 +185,9 @@ class SamplingParams:
min_p (float, optional): scale the most likely token to determine the minimum token probability. None means using C++ runtime default 0.0. Defaults to None.
beam_width_array (List[int], optional): The array of beam width using in Variable-Beam-Width-Search. Defaults to None.
logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None.
logprobs (int, optional): Number of log probabilities to return per output token. When set to 0, return only the sampled token's log probability.
When set to K > 0, return the top-K log probabilities plus the sampled token's log probability (appended as the last entry) if it is not within the top-K. Defaults to None.
logprobs_mode (LogprobMode): The mode of log probabilities to return. Defaults to LogprobMode.RAW.
prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None.
return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False.
return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False.
@ -219,6 +234,7 @@ class SamplingParams:
n: int = 1
best_of: Optional[int] = None
use_beam_search: bool = False
logprobs_mode: LogprobMode = LogprobMode.RAW
# Keep the fields below in sync with tllme.SamplingConfig or maintain the mapping table.
top_k: Optional[int] = None
@ -321,6 +337,8 @@ class SamplingParams:
f"under the greedy decoding."
)
self.logprobs_mode = LogprobMode(self.logprobs_mode)
if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1:
raise ValueError(
f"truncate_prompt_tokens must be >= 1, got {self.truncate_prompt_tokens}"
@ -329,8 +347,11 @@ class SamplingParams:
if self.guided_decoding is not None:
self.guided_decoding._validate()
# correct types as users might pass in logprob=True for Top-1 logprobs
self.logprobs = self.logprobs and int(self.logprobs)
# normalize types: users may pass logprobs=True for sampled-token-only logprobs (0) or logprobs=False for no logprobs
if self.logprobs is False:
self.logprobs = None
if self.logprobs is True:
self.logprobs = 0
self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs)
# NB: Static, because downstream code only holds instances of
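Usage sketch (hypothetical parameter values): requesting processed logprobs; note the normalization above, where logprobs=True becomes 0 (sampled token only) and logprobs=False becomes None.

from tensorrt_llm import SamplingParams

params = SamplingParams(
    max_tokens=8,
    temperature=0.8,
    top_k=50,
    logprobs=3,                  # top-3 plus the sampled token if it falls outside the top-3
    logprobs_mode="processed",   # converted to LogprobMode.PROCESSED by the validation above
)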
@ -494,7 +515,7 @@ class SamplingParams:
config_kwargs = {f: getattr(self, f) for f in fields}
if is_pytorch_backend:
config_kwargs["return_log_probs"] = bool(self.logprobs)
config_kwargs["return_log_probs"] = self.logprobs is not None
if self.prompt_logprobs and not self.return_context_logits:
logger.info(
"Since prompt_logprobs is requested but return_context_logits is False, "

View File

@ -21,7 +21,7 @@ l0_a30:
- unittest/_torch/modeling -k "modeling_out_of_tree"
- unittest/_torch/modeling -k "modeling_starcoder2"
- unittest/_torch/sampler/test_beam_search.py
- unittest/_torch/sampler/test_return_logits.py
- unittest/_torch/sampler/test_logits_logprobs.py
- test_e2e.py::test_openai_completions_with_logit_bias[torch_sampler]
- test_e2e.py::test_openai_chat_with_logit_bias[torch_sampler]
- test_e2e.py::test_openai_completions_with_logit_bias[trtllm_sampler]

View File

@ -284,6 +284,8 @@ test_e2e.py::test_ptp_quickstart_advanced_2gpus_sm120[Nemotron-Super-49B-v1-BF16
triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
unittest/_torch/sampler/test_logits_logprobs.py::test_generate_with_return_logits SKIP (https://nvbugs/5764627)
unittest/_torch/sampler/test_logits_logprobs.py::test_generate_async_with_return_logits SKIP (https://nvbugs/5764627)
examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977)
full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugspro.nvidia.com/bug/5794313)
examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502)

View File

@ -125,7 +125,7 @@ def check_generation_logits(beam: CompletionOutput,
def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams,
valid_tokens: int | None) -> None:
"""Check if the logprobs have the correct shape"""
if sampling_params.logprobs:
if sampling_params.logprobs is not None:
generated_tokens = valid_tokens if valid_tokens is not None else sampling_params.max_tokens
assert len(
beam.logprobs
@ -345,7 +345,7 @@ class GeneralTestParams:
prompt_len = len(input_tokens)
num_generated_tokens = 5
seq_len = prompt_len + num_generated_tokens
num_logprobs = 1
num_logprobs = 0
seq_slot = 4
end_id = 99
batch_size = 2
@ -541,7 +541,7 @@ def create_default_request(test_params: GeneralTestParams) -> LlmRequest:
end_id=test_params.end_id,
sampling_config=SamplingConfig(
sampling_params._get_sampling_config()),
return_log_probs=test_params.num_logprobs > 0,
return_log_probs=test_params.num_logprobs >= 0,
num_logprobs=test_params.num_logprobs,
is_streaming=False)
@ -590,7 +590,7 @@ def test_create_beam_history():
num_generated_tokens = test_params.num_generated_tokens
seq_slot = test_params.seq_slot
vocab_size = test_params.vocab_size
num_logprobs = test_params.num_logprobs
num_logprobs = test_params.num_logprobs + 1
cache_indirection = sampler.store.cache_indirection
original_tokens = sampler.store.original_tokens
original_logprobs = torch.zeros(
@ -635,7 +635,11 @@ def test_create_beam_history():
# set the logprobs in the request:
token_logprobs = sampler._convert_logprobs_tensor_to_list(
original_logprob_indices[:beam_width, :num_generated_tokens - 1],
original_logprobs[:beam_width, :num_generated_tokens - 1])
original_logprobs[:beam_width, :num_generated_tokens - 1],
None,
None,
None,
)
request.py_result.set_log_probs(
token_logprobs,
cum_log_probs=torch.zeros_like(
@ -657,9 +661,10 @@ def test_create_beam_history():
) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed."
# set the new log probs and tokens for the beam search sampling
sampler.store.new_log_probs[
sampler.store.sampled_log_probs[
seq_slot, :beam_width] = original_logprobs[:beam_width,
num_generated_tokens - 1, 0]
num_generated_tokens - 1,
0:1]
sampler.store.new_tokens[
0,
seq_slot, :beam_width] = original_logprob_indices[:beam_width,

View File

@ -0,0 +1,589 @@
import os
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.llm_data import llm_models_root
from utils.util import force_ampere
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.pyexecutor.sampling_utils import top_k_top_p_sampling_batch
from tensorrt_llm._torch.pyexecutor.sampling_utils_flashinfer import _StrategyImpls
from tensorrt_llm.llmapi.llm_utils import KvCacheConfig
prompts = ["A B C"]
global_kvcache_config = KvCacheConfig(
max_tokens=10000,
enable_block_reuse=True,
)
@pytest.fixture(scope="module", params=[False, True])
def gather_generation_logits_fixture(request) -> bool:
return request.param
@pytest.fixture(scope="module", params=[False, True])
def gather_context_logits_fixture(request) -> bool:
return request.param
@pytest.fixture(scope="module", params=[False, True])
def disable_overlap_scheduler_fixture(request) -> bool:
return request.param
@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"])
def sampler_type_fixture(request) -> str:
return request.param
class CacheSalter:
_salt = 0
@classmethod
def get_salt_unique(cls) -> str:
cls._salt += 1
return str(cls._salt)
@classmethod
def get_salt_shared(cls) -> str:
return str(0)
@classmethod
def get_salt(cls, reuse_cache: bool) -> str:
if reuse_cache:
salt = cls.get_salt_shared()
else:
salt = cls.get_salt_unique()
return salt
@pytest.fixture(scope="module")
def llm(
gather_context_logits_fixture: bool,
gather_generation_logits_fixture: bool,
sampler_type_fixture: str,
disable_overlap_scheduler_fixture: bool,
):
gather_generation_logits = gather_generation_logits_fixture
sampler_type = sampler_type_fixture
disable_overlap_scheduler = disable_overlap_scheduler_fixture
llm = LLM(
model=os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"),
kv_cache_config=global_kvcache_config,
gather_generation_logits=gather_generation_logits,
max_batch_size=128,  # reduce buffer sizes, especially for generation logits
sampler_type=sampler_type,
disable_overlap_scheduler=disable_overlap_scheduler,
)
# FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178.
# Remove patch below once fixed.
old_exit = LLM.__exit__
def _exit_with_xfail_on_timeout(self, exc_type, exc_value, traceback) -> bool:
import _pytest.outcomes
try:
return old_exit(self, exc_type, exc_value, traceback)
except _pytest.outcomes.Failed as e:
if e.msg and "pytest-timeout" in e.msg.lower():
pytest.xfail("Known LLM shutdown issue (https://nvbugs/5577178).")
else:
raise
with pytest.MonkeyPatch.context() as patch:
patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout)
with llm:
yield llm
@pytest.fixture(scope="module", params=[False, True])
def simple_llm(request) -> LLM:
disable_flashinfer_sampling = request.param
llm = LLM(
model=os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"),
max_batch_size=8,
disable_flashinfer_sampling=disable_flashinfer_sampling,
)
return llm
@force_ampere # Save H100 resource
@pytest.mark.parametrize("reuse_cache", [False, True])
@pytest.mark.parametrize("return_log_probs", [False, True])
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
@pytest.mark.timeout(120, method="signal")
@pytest.mark.threadleak(enabled=False)
def test_generate_with_return_logits(
llm,
gather_context_logits_fixture: bool,
gather_generation_logits_fixture: bool,
reuse_cache: bool,
return_log_probs: bool,
):
gather_context_logits = gather_context_logits_fixture
gather_generation_logits = gather_generation_logits_fixture
if not (gather_context_logits or gather_generation_logits or return_log_probs): # prune space
pytest.skip("Nothing to test")
sampling_params = SamplingParams(
max_tokens=8,
return_context_logits=gather_context_logits,
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs,
)
for output in llm.generate(
prompts,
sampling_params=sampling_params,
cache_salt=[CacheSalter.get_salt(reuse_cache) for _ in prompts],
):
if gather_context_logits:
assert output.context_logits is not None
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
expected_len = len(prompts[0].split()) + 1
try:
assert expected_len == output.context_logits.shape[0]
except AssertionError:
# FIXME: Remove this once the bug has been fixed
if gather_context_logits and reuse_cache:
pytest.xfail("Known bug: https://nvbugs/5577178")
raise
else:
assert output.context_logits is None
for sequence in output.outputs:
assert sequence.length == sampling_params.max_tokens
if gather_generation_logits:
gen_logits = sequence.generation_logits
assert gen_logits is not None
assert gen_logits.ndim == 2
assert gen_logits.shape[0] == sampling_params.max_tokens
assert torch.argmax(gen_logits, dim=1).tolist() == sequence.token_ids
else:
assert sequence.generation_logits is None
if return_log_probs:
assert len(sequence.logprobs) == sampling_params.max_tokens
else:
assert len(sequence.logprobs) == 0
@force_ampere # Save H100 resource
@pytest.mark.parametrize("reuse_cache", [False, True])
@pytest.mark.parametrize("return_log_probs", [False, True])
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
@pytest.mark.timeout(120, method="signal")
@pytest.mark.threadleak(enabled=False)
def test_generate_async_with_return_logits(
llm,
gather_context_logits_fixture: bool,
gather_generation_logits_fixture: bool,
reuse_cache: bool,
return_log_probs: bool,
):
gather_context_logits = gather_context_logits_fixture
gather_generation_logits = gather_generation_logits_fixture
if not (gather_context_logits or gather_generation_logits or return_log_probs): # prune space
pytest.skip("Nothing to test")
sampling_params = SamplingParams(
max_tokens=8,
return_context_logits=gather_context_logits,
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs,
)
for idx, output in enumerate(
llm.generate_async(
prompts[0],
sampling_params=sampling_params,
streaming=True,
cache_salt=CacheSalter.get_salt(reuse_cache),
)
):
if gather_context_logits:
assert output.context_logits is not None
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
expected_len = len(prompts[0].split()) + 1
try:
assert expected_len == output.context_logits.shape[0]
except AssertionError:
# FIXME: Remove this once the bug has been fixed
if gather_context_logits and reuse_cache:
pytest.xfail("Known bug: https://nvbugs/5577178")
raise
else:
assert output.context_logits is None
for sequence in output.outputs:
assert sequence.length == idx + 1
if gather_generation_logits:
gen_logits = sequence.generation_logits
assert gen_logits is not None
assert gen_logits.ndim == 2
assert gen_logits.shape[0] == 1
try:
assert torch.argmax(gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1]
except AssertionError:
# FIXME: Remove xfail once the bug is fixed
pytest.xfail("Known bug: https://nvbugs/5573238")
else:
assert sequence.generation_logits is None
if return_log_probs:
assert len(sequence.logprobs) == idx + 1
else:
assert len(sequence.logprobs) == 0
@pytest.mark.parametrize("logprobs_k", [0, 1, 3], ids=["top_0", "top_1", "top_3"])
@pytest.mark.parametrize("logprobs_mode", ["raw", "processed"])
@pytest.mark.threadleak(enabled=False)
def test_sampled_token_always_in_logprobs(logprobs_k: int, logprobs_mode: str, simple_llm: LLM):
"""Two scenarios:
- logprobs=0: Returns only sampled token (1 element)
- logprobs=K (K>0): Returns top-K tokens + sampled token if not in top-K (up to K+1 elements)
"""
sampling_params = SamplingParams(
max_tokens=8,
temperature=0.7,
top_p=0.9,
logprobs=logprobs_k,
logprobs_mode=logprobs_mode,
)
for output in simple_llm.generate(["The future of AI is"], sampling_params=sampling_params):
print(f"\n{'=' * 80}")
print(f"Generated text: {output.outputs[0].text!r}")
print(f"Generated token IDs: {output.outputs[0].token_ids}")
logprobs = output.outputs[0].logprobs
token_ids = output.outputs[0].token_ids
assert len(logprobs) == sampling_params.max_tokens, (
f"Expected {sampling_params.max_tokens} logprob entries, got {len(logprobs)}"
)
for token_idx, (sampled_token_id, token_logprobs) in enumerate(zip(token_ids, logprobs)):
print(
f"\n Token {token_idx}: "
f"ID={sampled_token_id}, "
f"Text={simple_llm.tokenizer.decode([sampled_token_id])!r}"
)
assert sampled_token_id in token_logprobs, (
f"Token {token_idx}: Sampled token ID {sampled_token_id} not in logprobs dict: {token_logprobs.keys()}"
)
if logprobs_k == 0:
assert len(token_logprobs) == 1, (
f"Token {token_idx}: Expected 1 logprob (sampled only), got {len(token_logprobs)}"
)
else:
assert len(token_logprobs) <= logprobs_k + 1, (
f"Token {token_idx}: Expected at most {logprobs_k + 1} logprobs, got {len(token_logprobs)}"
)
assert len(token_logprobs) >= 1
sorted_tokens_by_prob = sorted(
token_logprobs.items(), key=lambda x: x[1].logprob, reverse=True
)
if logprobs_k > 0:
sampled_token_rank = token_logprobs[sampled_token_id].rank
sampled_in_topk = sampled_token_rank <= logprobs_k
if not sampled_in_topk:
assert sorted_tokens_by_prob[-1][0] == sampled_token_id, (
f"Token {token_idx}: Sampled token (ID={sampled_token_id}, rank={sampled_token_rank}) "
f"not in top-{logprobs_k}, should be last in sorted list, "
f"but last token is ID={sorted_tokens_by_prob[-1][0]}"
)
for rank_idx, (token_id, logprob_obj) in enumerate(sorted_tokens_by_prob, start=1):
token_text = simple_llm.tokenizer.decode([token_id])
is_sampled = "← SAMPLED" if token_id == sampled_token_id else ""
print(
f" • Token {token_id:5d} ({token_text:15s}): "
f"logprob={logprob_obj.logprob:8.4f}, "
f"rank={logprob_obj.rank} {is_sampled}"
)
if logprobs_k > 0 and sampled_in_topk:
assert logprob_obj.rank == rank_idx, (
f"Token {token_idx}: Token {token_id} rank mismatch. "
f"Expected rank {rank_idx} (by sorted position), got {logprob_obj.rank}"
)
print(f"{'=' * 80}\n")
@pytest.mark.parametrize("logprobs_k", [0, 2], ids=["top_0", "top_2"])
@pytest.mark.threadleak(enabled=False)
def test_logprobs_with_grouped_samplings_strategies(logprobs_k: int, simple_llm: LLM):
"""Test logprobs when requests are reordered by sampling strategy grouping"""
test_prompts = [
"The capital of France is",
"The future of AI is",
"Hello, my name is",
"Hello, my name is",
"Write a short story about a cat",
]
# Causes reordering: [0,1,2,3,4] → [0,2,3,1,4]
sampling_params_list = [
SamplingParams(
max_tokens=6,
temperature=0.8,
top_k=50,
logprobs=logprobs_k,
return_generation_logits=True,
),
SamplingParams(
max_tokens=6,
temperature=0.8,
top_p=0.9,
logprobs=logprobs_k,
return_generation_logits=True,
),
SamplingParams(
max_tokens=6,
temperature=0.8,
top_k=50,
logprobs=logprobs_k,
return_generation_logits=True,
),
SamplingParams(
max_tokens=6, temperature=0.8, top_k=50, logprobs=None, return_generation_logits=True
),
SamplingParams(
max_tokens=6,
temperature=0.8,
top_p=0.9,
logprobs=logprobs_k,
return_generation_logits=True,
),
]
outputs = list(simple_llm.generate(test_prompts, sampling_params=sampling_params_list))
for req_idx, output in enumerate(outputs):
generation_logits = output.outputs[0].generation_logits.to(device="cuda")
token_ids = output.outputs[0].token_ids
logprobs = output.outputs[0].logprobs
if sampling_params_list[req_idx].logprobs is None:
assert len(logprobs) == 0
continue
assert generation_logits is not None
assert len(logprobs) == len(token_ids), "Logprobs length mismatch"
# generation_logits might be shorter than token_ids
num_logits = len(generation_logits)
for token_idx, (sampled_token_id, token_logprobs_dict) in enumerate(
zip(token_ids[:num_logits], logprobs[:num_logits])
):
returned_logprob = token_logprobs_dict[sampled_token_id].logprob
logits_for_token = generation_logits[token_idx]
expected_logprobs = torch.nn.functional.log_softmax(logits_for_token, dim=-1).to(
device="cpu"
)
expected_logprob = expected_logprobs[sampled_token_id].item()
print(
f"Req {req_idx}, Token {token_idx}: returned={returned_logprob:.6f}, expected={expected_logprob:.6f}"
)
torch.testing.assert_close(returned_logprob, expected_logprob)
@pytest.mark.parametrize("logprobs_k", [0, 2], ids=["top_0", "top_2"])
@pytest.mark.threadleak(enabled=False)
def test_processed_logprobs_e2e(logprobs_k: int, simple_llm: LLM):
"""Test logprobs when requests are reordered by sampling strategy grouping"""
test_prompts = [
"The capital of France is",
"The future of AI is",
"Hello, my name is",
"Write a short story about a cat",
"Hello, my name is",
"Write a short story about a cat",
]
sampling_params_list = [
# greedy decoding
SamplingParams(
max_tokens=6,
temperature=0.0,
logprobs=logprobs_k,
return_generation_logits=True,
logprobs_mode="processed",
),
# temperature sampling
SamplingParams(
max_tokens=6,
temperature=0.8,
logprobs=logprobs_k,
return_generation_logits=True,
logprobs_mode="processed",
),
# top-p sampling
SamplingParams(
max_tokens=6,
temperature=0.8,
top_p=0.9,
logprobs=logprobs_k,
return_generation_logits=True,
logprobs_mode="processed",
),
# top-k sampling
SamplingParams(
max_tokens=6,
temperature=0.8,
top_k=50,
logprobs=logprobs_k,
return_generation_logits=True,
logprobs_mode="processed",
),
# top-p sampling 2
SamplingParams(
max_tokens=6,
temperature=0.8,
top_p=0.9,
logprobs=logprobs_k,
return_generation_logits=True,
logprobs_mode="processed",
),
# top-p and top-k sampling
SamplingParams(
max_tokens=6,
temperature=0.8,
top_p=0.9,
top_k=50,
logprobs=logprobs_k,
return_generation_logits=True,
logprobs_mode="processed",
),
]
outputs = list(simple_llm.generate(test_prompts, sampling_params=sampling_params_list))
for req_idx, output in enumerate(outputs):
generation_logits = output.outputs[0].generation_logits.to(device="cuda")
token_ids = output.outputs[0].token_ids
logprobs = output.outputs[0].logprobs
assert generation_logits is not None
assert len(logprobs) == len(token_ids), "Logprobs length mismatch"
# generation_logits might be shorter than token_ids
num_logits = len(generation_logits)
for token_idx, token_logprobs_dict in enumerate(logprobs[:num_logits]):
assert token_ids[token_idx] in token_logprobs_dict, "Sampled token not in logprobs"
logits_for_token = generation_logits[token_idx : token_idx + 1]
topk = sampling_params_list[req_idx].top_k
topp = sampling_params_list[req_idx].top_p
temperature = sampling_params_list[req_idx].temperature
if sampling_params_list[req_idx]._greedy_decoding:
probs = torch.zeros_like(logits_for_token)
probs[0, token_ids[token_idx]] = 1.0
else:
topk = topk if topk is not None else logits_for_token.shape[-1]
topp = topp if topp is not None else 1.0
temperature = temperature if temperature is not None else 1.0
# apply top-k / top-p masking
if simple_llm.args.disable_flashinfer_sampling:
_, probs = top_k_top_p_sampling_batch(
logits_for_token, top_k=topk, top_p=topp, temperature=temperature
)
else:
_, probs = _StrategyImpls.StrategyImplWithProbs._sample_with_probs(
logits_for_token,
group_logit_indices=None,
top_k=torch.tensor([topk], dtype=torch.int32, device="cuda"),
top_p=torch.tensor([topp], dtype=torch.float32, device="cuda"),
temperature=torch.tensor([temperature], dtype=torch.float32, device="cuda"),
generator=None,
)
if temperature != 0:
logits_for_token /= temperature
adjusted_logits_for_token = torch.where(probs != 0, logits_for_token, float("-inf"))[0]
expected_logprobs = torch.nn.functional.log_softmax(
adjusted_logits_for_token, dim=-1
).to(device="cpu")
for logprob_token, logprob_values in token_logprobs_dict.items():
expected_logprob = expected_logprobs[logprob_token].item()
returned_logprob = logprob_values.logprob
print(
f"Req {req_idx}, Token {token_idx}: "
f"returned={returned_logprob:.6f}, expected={expected_logprob:.6f}"
)
torch.testing.assert_close(returned_logprob, expected_logprob)
@force_ampere
@pytest.mark.gpu2
def test_logprobs_match_hf_tp2():
model_path = os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0")
llm = LLM(
model=model_path,
tensor_parallel_size=2,
)
prompts = ["The future of the AI is"]
sampling_params = SamplingParams(
max_tokens=10,
temperature=1.0,
logprobs=0,
)
hf_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to(
"cuda"
)
hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
output = list(llm.generate(prompts, sampling_params=sampling_params))[0]
trtllm_token_ids = output.outputs[0].token_ids
trtllm_logprobs = torch.tensor(
[list(lp.values())[0].logprob for lp in output.outputs[0].logprobs]
)
base_ids = hf_tokenizer.encode(prompts[0], return_tensors="pt").to("cuda")
hf_logprobs = []
for i, token_id in enumerate(trtllm_token_ids):
if i > 0:
prev_tokens = torch.tensor([trtllm_token_ids[:i]], device="cuda")
input_ids = torch.cat([base_ids, prev_tokens], dim=1)
else:
input_ids = base_ids
with torch.no_grad():
logits = hf_model(input_ids).logits[0, -1, :]
hf_logprobs.append(torch.log_softmax(logits, dim=-1)[token_id].item())
hf_logprobs = torch.tensor(hf_logprobs)
print(f"\nTensorRT-LLM logprobs: {trtllm_logprobs}")
print(f"HuggingFace logprobs: {hf_logprobs}")
print(f"Diff: {(trtllm_logprobs - hf_logprobs).abs()}")
torch.testing.assert_close(trtllm_logprobs, hf_logprobs, atol=0.15, rtol=0)

View File

@ -1,239 +0,0 @@
import os
import pytest
import torch
from utils.llm_data import llm_models_root
from utils.util import force_ampere
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_utils import KvCacheConfig
prompts = ["A B C"]
global_kvcache_config = KvCacheConfig(
max_tokens=10000,
enable_block_reuse=True,
)
@pytest.fixture(scope="module", params=[False, True])
def gather_generation_logits_fixture(request) -> bool:
return request.param
@pytest.fixture(scope="module", params=[False, True])
def gather_context_logits_fixture(request) -> bool:
return request.param
@pytest.fixture(scope="module", params=[False, True])
def disable_overlap_scheduler_fixture(request) -> bool:
return request.param
@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"])
def sampler_type_fixture(request) -> str:
return request.param
class CacheSalter:
_salt = 0
@classmethod
def get_salt_unique(cls) -> str:
cls._salt += 1
return str(cls._salt)
@classmethod
def get_salt_shared(cls) -> str:
return str(0)
@classmethod
def get_salt(cls, reuse_cache: bool) -> str:
if reuse_cache:
salt = cls.get_salt_shared()
else:
salt = cls.get_salt_unique()
return salt
@pytest.fixture(scope="module")
def llm(
gather_context_logits_fixture: bool,
gather_generation_logits_fixture: bool,
sampler_type_fixture: str,
disable_overlap_scheduler_fixture: bool,
):
gather_generation_logits = gather_generation_logits_fixture
sampler_type = sampler_type_fixture
disable_overlap_scheduler = disable_overlap_scheduler_fixture
llm = LLM(
model=os.path.join(llm_models_root(), "llama-models-v2",
"TinyLlama-1.1B-Chat-v1.0"),
kv_cache_config=global_kvcache_config,
gather_generation_logits=gather_generation_logits,
max_batch_size=
128, # reduce buffer sizes, specially for generation logits
sampler_type=sampler_type,
disable_overlap_scheduler=disable_overlap_scheduler,
)
# FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178.
# Remove patch below once fixed.
old_exit = LLM.__exit__
def _exit_with_xfail_on_timeout(self, exc_type, exc_value,
traceback) -> bool:
import _pytest.outcomes
try:
return old_exit(self, exc_type, exc_value, traceback)
except _pytest.outcomes.Failed as e:
if e.msg and "pytest-timeout" in e.msg.lower():
pytest.xfail(
"Known LLM shutdown issue (https://nvbugs/5577178).")
else:
raise
with pytest.MonkeyPatch.context() as patch:
patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout)
with llm:
yield llm
@force_ampere # Save H100 resource
@pytest.mark.parametrize("reuse_cache", [False, True])
@pytest.mark.parametrize("return_log_probs", [False, True])
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
@pytest.mark.timeout(120, method="signal")
@pytest.mark.threadleak(enabled=False)
def test_generate_with_return_logits(
llm,
gather_context_logits_fixture: bool,
gather_generation_logits_fixture: bool,
reuse_cache: bool,
return_log_probs: bool,
):
gather_context_logits = gather_context_logits_fixture
gather_generation_logits = gather_generation_logits_fixture
if not (gather_context_logits or gather_generation_logits
or return_log_probs): # prune space
pytest.skip("Nothing to test")
sampling_params = SamplingParams(
max_tokens=8,
return_context_logits=gather_context_logits,
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs,
)
for output in llm.generate(
prompts,
sampling_params=sampling_params,
cache_salt=[CacheSalter.get_salt(reuse_cache) for _ in prompts],
):
if gather_context_logits:
assert output.context_logits is not None
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
expected_len = len(prompts[0].split()) + 1
try:
assert expected_len == output.context_logits.shape[0]
except AssertionError:
# FIXME: Remove this once the bug has been fixed
if gather_context_logits and reuse_cache:
pytest.xfail("Known bug: https://nvbugs/5577178")
raise
else:
assert output.context_logits is None
for sequence in output.outputs:
assert sequence.length == sampling_params.max_tokens
if gather_generation_logits:
gen_logits = sequence.generation_logits
assert gen_logits is not None
assert gen_logits.ndim == 2
assert gen_logits.shape[0] == sampling_params.max_tokens
assert torch.argmax(gen_logits,
dim=1).tolist() == sequence.token_ids
else:
assert sequence.generation_logits is None
if return_log_probs:
assert len(sequence.logprobs) == sampling_params.max_tokens
else:
assert len(sequence.logprobs) == 0
@force_ampere # Save H100 resource
@pytest.mark.parametrize("reuse_cache", [False, True])
@pytest.mark.parametrize("return_log_probs", [False, True])
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
@pytest.mark.timeout(120, method="signal")
@pytest.mark.threadleak(enabled=False)
def test_generate_async_with_return_logits(
llm,
gather_context_logits_fixture: bool,
gather_generation_logits_fixture: bool,
reuse_cache: bool,
return_log_probs: bool,
):
gather_context_logits = gather_context_logits_fixture
gather_generation_logits = gather_generation_logits_fixture
if not (gather_context_logits or gather_generation_logits
or return_log_probs): # prune space
pytest.skip("Nothing to test")
sampling_params = SamplingParams(
max_tokens=8,
return_context_logits=gather_context_logits,
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs)
for idx, output in enumerate(
llm.generate_async(
prompts[0],
sampling_params=sampling_params,
streaming=True,
cache_salt=CacheSalter.get_salt(reuse_cache),
)):
if gather_context_logits:
assert output.context_logits is not None
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
expected_len = len(prompts[0].split()) + 1
try:
assert expected_len == output.context_logits.shape[0]
except AssertionError:
# FIXME: Remove this once the bug has been fixed
if gather_context_logits and reuse_cache:
pytest.xfail("Known bug: https://nvbugs/5577178")
raise
else:
assert output.context_logits is None
for sequence in output.outputs:
assert sequence.length == idx + 1
if gather_generation_logits:
gen_logits = sequence.generation_logits
assert gen_logits is not None
assert gen_logits.ndim == 2
assert gen_logits.shape[0] == 1
try:
assert torch.argmax(
gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1]
except AssertionError:
# FIXME: Remove xfail once the bug is fixed
pytest.xfail("Known bug: https://nvbugs/5573238")
else:
assert sequence.generation_logits is None
if return_log_probs:
assert len(sequence.logprobs) == idx + 1
else:
assert len(sequence.logprobs) == 0

View File

@ -86,7 +86,7 @@ class TestStrategySelection:
is_context_init_state: bool # Torch sampler accesses this, but it does not affect this test
def get_beam_width_by_iter(
self, for_next_iteration: bool
self, for_next_iteration: bool = False
) -> int: # Torch sampler accesses this, but it does not affect this test
return self.sampling_config.beam_width
@ -445,12 +445,22 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool)
def py_return_context_logits(self) -> bool:
return self._return_context_logits
def get_beam_width_by_iter(
self, for_next_iteration: bool = False
) -> int: # Torch sampler accesses this, but it does not affect this test
return self.sampling_config.beam_width
class GenRequestMock:
def __init__(self, draft_len: int):
self.is_context_init_state = False
self.py_draft_tokens = torch.empty(draft_len, dtype=torch.int32, device=device)
self.sampling_config = SamplingConfig(beam_width=1)
def get_beam_width_by_iter(
self, for_next_iteration: bool = False
) -> int: # Torch sampler accesses this, but it does not affect this test
return self.sampling_config.beam_width
class ScheduledRequestsMock:
@property
def context_requests(self) -> list[LlmRequest]:

View File

@ -31,6 +31,7 @@ from tensorrt_llm.llmapi import (CalibConfig, CompletionOutput,
from tensorrt_llm.llmapi.llm_args import SamplerType
from tensorrt_llm.llmapi.llm_utils import LlmArgs
from tensorrt_llm.logger import Singleton
from tensorrt_llm.sampling_params import LogprobMode
def repr_annotation(field_type: type) -> str:

View File

@ -15,5 +15,8 @@ methods:
prompt_ignore_length:
annotation: Optional[int]
default: null
logprobs_mode:
annotation: LogprobMode
default: LogprobMode.RAW
return_annotation: None
properties: {}