Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-03 17:52:19 +08:00)
[TRTLLM-9735][feat] Add processed logprobs functionality to TorchSampler (#9675)
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Co-authored-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Co-authored-by: Erin Ho <14718778+hchings@users.noreply.github.com>
parent: cfebfbb505
commit: 0cfd08745c
@@ -1020,7 +1020,6 @@ common-files: &common_files |
 tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
 tests/unittest/_torch/sampler/test_beam_search.py |
 tests/unittest/_torch/sampler/test_best_of_n.py |
-tests/unittest/_torch/sampler/test_return_logits.py |
 tests/unittest/_torch/sampler/test_torch_multi_arange.py |
 tests/unittest/_torch/sampler/test_trtllm_sampler.py |
 tests/unittest/_torch/speculative/test_draft_target.py |
@@ -1061,7 +1061,6 @@ exclude = [
 "tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
 "tests/unittest/_torch/sampler/test_beam_search.py",
 "tests/unittest/_torch/sampler/test_best_of_n.py",
-"tests/unittest/_torch/sampler/test_return_logits.py",
 "tests/unittest/_torch/sampler/test_torch_multi_arange.py",
 "tests/unittest/_torch/sampler/test_trtllm_sampler.py",
 "tests/unittest/_torch/speculative/test_draft_target.py",
@@ -8,6 +8,7 @@ import tensorrt_llm.bindings
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.bindings import executor as tllm_executor
 from tensorrt_llm.executor.result import TokenLogprobs
+from tensorrt_llm.sampling_params import LogprobMode
 
 SamplingConfig = tensorrt_llm.bindings.SamplingConfig
 '''
@@ -485,6 +486,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
 is_first_draft: bool = False,
 use_chunked_generation_logits: bool = True,
 logits_chunk_size: int = 8,
+logprobs_mode: LogprobMode = LogprobMode.RAW,
 **kwargs):
 
 self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@@ -566,6 +568,9 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
 # currently, keep py_stop_words_list as python list, rather than tensor.
 self.py_stop_words_list = stop_words_list
 
+self.py_logprobs_mode = LogprobMode(
+logprobs_mode) # handle passed a raw string
+
 self.py_result = PyResult(
 prompt_len=self.py_prompt_len,
 max_new_tokens=self.py_max_new_tokens,
@@ -825,7 +830,10 @@ def executor_request_to_llm_request(
 arrival_time=getattr(executor_request, "py_arrival_time", None),
 py_multimodal_data=getattr(executor_request, "py_multimodal_data",
 None),
-kv_cache_retention_config=executor_request.kv_cache_retention_config)
+kv_cache_retention_config=executor_request.kv_cache_retention_config,
+logprobs_mode=getattr(executor_request, "py_logprobs_mode",
+LogprobMode.RAW),
+)
 if child_req_ids:
 for child_id in child_req_ids:
 llm_request.create_child_request(child_id)
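For illustration, a minimal sketch (not part of the commit) of the string-to-enum coercion performed above; since LogprobMode is a StrEnum, either the member or its raw string value is accepted:

from tensorrt_llm.sampling_params import LogprobMode

assert LogprobMode("processed") is LogprobMode.PROCESSED  # raw string is coerced to the member
assert LogprobMode(LogprobMode.RAW) is LogprobMode.RAW    # passing the member is a no-op
assert LogprobMode.PROCESSED == "processed"               # StrEnum members compare equal to their value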
(File diff suppressed because it is too large.)
@@ -266,7 +266,8 @@ def greedy_search_sampling_batch(
 next_tokens = torch.argmax(logits, dim=-1)
 softmax: Optional[torch.Tensor] = None
 if return_probs:
-softmax = torch.softmax(logits, dim=-1)
+softmax = torch.zeros_like(logits)
+softmax.scatter_(1, next_tokens.unsqueeze(-1), 1.0)
 return next_tokens, softmax
 
 
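For illustration, a minimal sketch (not part of the commit) of what the change above produces: under greedy decoding the returned "softmax" is now a one-hot distribution over the argmax token, so the processed logprob of the chosen token presumably comes out as log(1.0) = 0.0:

import torch

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
next_tokens = torch.argmax(logits, dim=-1)          # tensor([0, 2])

# one-hot "probabilities", as built by greedy_search_sampling_batch above
probs = torch.zeros_like(logits)
probs.scatter_(1, next_tokens.unsqueeze(-1), 1.0)   # each row is one-hot on the argmax

# logprob of the sampled token under this distribution is log(1.0) == 0.0
print(torch.log(probs.gather(1, next_tokens.unsqueeze(-1))))  # tensor([[0.], [0.]])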
@@ -471,10 +472,10 @@ def sample(
 strategy: Strategy,
 logits: torch.Tensor,
 *,
-generator: Optional[torch.Generator] = None,
+generator: torch.Generator | None = None,
 group_metadata: StrategyMetadata | None = None,
 return_probs: bool = True,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor | None, float | None]:
 match strategy:
 case ("top_k", top_k, temperature):
 tokens, softmax = top_k_sampling_batch(
@@ -506,6 +507,7 @@ def sample(
 )
 case ("greedy", None):
 tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs)
+temperature = None
 case ("beam_search", beam_width_in, beam_width_out, temperature):
 assert group_metadata is not None and isinstance(group_metadata, BeamSearchMetadata), (
 "BeamSearchMetadata is required for beam_search_sampling_batch"
@@ -519,7 +521,7 @@ def sample(
 generator=generator,
 return_probs=return_probs,
 )
-return tokens, softmax
+return tokens, softmax, temperature
 
 
 GenericStrategyKeyType = TypeVar("GenericStrategyKeyType")
@@ -545,11 +547,11 @@ class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
 strategies: list[Strategy],
 logits: torch.Tensor,
 *,
-group_logit_indices: Optional[torch.Tensor] = None,
-generator: Optional[torch.Generator] = None,
+group_logit_indices: torch.Tensor | None = None,
+generator: torch.Generator | None = None,
 return_probs: bool,
 group_metadata: StrategyMetadata | None = None,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]:
 raise NotImplementedError
 
 
@@ -579,11 +581,11 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
 strategies: list[Strategy],
 logits: torch.Tensor,
 *,
-group_logit_indices: Optional[torch.Tensor] = None,
-generator: Optional[torch.Generator] = None,
+group_logit_indices: torch.Tensor | None = None,
+generator: torch.Generator | None = None,
 return_probs: bool,
 group_metadata: StrategyMetadata | None = None,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor | None, float | None]:
 if group_key[0] == "beam_search":
 beam_width_in = group_key[1]
 else:
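For illustration, a minimal sketch (not part of the commit) of why sample() now also returns the strategy's temperature: with the masked probabilities and the temperature in hand, the sampler can presumably derive "processed" logprobs by renormalizing only over the surviving tokens. The helper name and the exact masking/tie handling below are illustrative assumptions:

import torch

def processed_logprobs_topk(logits: torch.Tensor, top_k: int, temperature: float) -> torch.Tensor:
    # illustrative helper, not the library implementation:
    # temperature-scale, keep only the top-k logits, then renormalize in log space
    scaled = logits / temperature
    kth = torch.topk(scaled, k=top_k, dim=-1).values[..., -1:]
    masked = torch.where(scaled >= kth, scaled, torch.full_like(scaled, float("-inf")))
    return torch.log_softmax(masked, dim=-1)  # masked-out tokens end up at -inf

logits = torch.randn(2, 16)
lp = processed_logprobs_topk(logits, top_k=4, temperature=0.7)
print(torch.isfinite(lp).sum(dim=-1))  # exactly 4 finite entries per row (barring ties)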
@@ -141,8 +141,9 @@ class _StrategyImpls:
 *,
 group_logit_indices: Optional[torch.Tensor],
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-probs = self._prepare_probs_with_temperature(logits, group_logit_indices, None)
-new_tokens, _ = greedy_search_sampling_batch(probs, return_probs=False)
+if group_logit_indices is not None:
+logits = torch.index_select(logits, 0, group_logit_indices) # ensures copy
+new_tokens, probs = greedy_search_sampling_batch(logits, return_probs=True)
 return new_tokens, probs
 
 @classmethod
@@ -240,6 +241,9 @@ class _StrategyImpls:
 return True
 
 class GreedyWithProbs(StrategyImplWithProbs):
+def __init__(self):
+self._temperature = None
+
 @override
 @classmethod
 def from_strategies(
@@ -425,6 +429,9 @@ class _StrategyImpls:
 return False
 
 class GreedySampleOnly(StrategyImplSampleOnly):
+def __init__(self):
+self._temperature = None
+
 @override
 @classmethod
 def from_strategies(
@@ -722,7 +729,7 @@ class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpl
 generator: Optional[torch.Generator] = None,
 return_probs: bool,
 group_metadata: StrategyMetadata | None = None,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
 if hasattr(group_key, "static_beam_width_in"):
 beam_width_in = group_key.static_beam_width_in
 else:
@@ -735,9 +742,16 @@ class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpl
 assert return_probs == group_key.computes_probs()
 
 strategy_impl_cls = group_key
-return strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device).sample(
+sampling_object = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device)
+next_tokens, softmax = sampling_object.sample(
 logits,
 group_logit_indices=group_logit_indices,
 generator=generator,
 group_metadata=group_metadata,
 )
+temperature = (
+sampling_object._temperature.unsqueeze(-1)
+if sampling_object._temperature is not None
+else None
+)
+return next_tokens, softmax, temperature
@@ -215,7 +215,7 @@ class MTPSampler(TorchSampler):
 
 SampleState = SampleStateMTP
 
-@dataclass(frozen=True, kw_only=True)
+@dataclass(kw_only=True)
 class Store(TorchSampler.Store):
 new_tokens: torch.Tensor
 next_new_tokens: torch.Tensor
@@ -562,6 +562,7 @@ class BaseWorker(GenerationExecutor):
 cache_salt_id=request.cache_salt_id)
 executor_request.py_num_logprobs = request.sampling_params.logprobs
 executor_request.py_lora_path = py_lora_path
+executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode
 
 if self._is_pytorch_backend and request.multimodal_params is not None:
 if request.multimodal_params.multimodal_data is not None:
@@ -221,7 +221,7 @@ class GenerationExecutor(ABC):
 self, request: GenerationRequest) -> Optional[LogprobParams]:
 """Store logprobs-related fields from request for the later logprob calculation."""
 logprob_params = None
-if request.sampling_params.logprobs or request.sampling_params.prompt_logprobs:
+if request.sampling_params.logprobs is not None or request.sampling_params.prompt_logprobs:
 logprob_params = LogprobParams(
 logprobs=request.sampling_params.logprobs,
 prompt_logprobs=request.sampling_params.prompt_logprobs,
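For illustration, a minimal sketch (not part of the commit) of why the check above moves from truthiness to `is not None`: logprobs=0 is now a valid request (sampled-token-only logprobs), but 0 is falsy in Python:

logprobs = 0              # valid request: return only the sampled token's logprob

if logprobs:              # old-style check: 0 is falsy, so the request would be ignored
    print("unreachable")

if logprobs is not None:  # new-style check: 0 is handled, only None means "disabled"
    print("logprobs requested")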
@@ -933,6 +933,21 @@ def compute_logprobs(
 logits = logits[:len(tokens)]
 
 logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1)
 
+# only return sampled token
+if top_k == 0:
+results: TokenLogprobs = []
+if tokens is not None:
+for t in range(logprobs.size(0)):
+token_id = tokens[t]
+token_logprob = logprobs[t, token_id].item()
+rank = (logprobs[t] > token_logprob).sum().item() + 1
+token_dict = {
+token_id: Logprob(logprob=token_logprob, rank=rank)
+}
+results.append(token_dict)
+return results
+
 topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)
 
 results: TokenLogprobs = []
@@ -961,7 +976,7 @@ def compute_logprobs(
 None) if k_prompt_logprobs and context_logits is not None else None
 generation_logprobs = _topk_logprobs(
 generation_logits, k_logprobs, output_token_ids
-) if k_logprobs and generation_logits is not None else None
+) if k_logprobs is not None and generation_logits is not None else None
 
 return LogProbsResult(prompt=prompt_logprobs,
 generation=generation_logprobs)
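For illustration, a minimal sketch (not part of the commit) reproducing the rank rule used in the new top_k == 0 branch (rank = 1 + number of tokens with a strictly larger logprob):

import torch

logprobs = torch.log_softmax(torch.tensor([2.0, 1.0, 0.5, 3.0]), dim=-1)
sampled = 1                                    # pretend token id 1 was sampled
token_logprob = logprobs[sampled].item()
rank = (logprobs > token_logprob).sum().item() + 1
print(rank)  # 3: token ids 3 and 0 score higher, so the sampled token ranks third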
@@ -666,7 +666,7 @@ class BaseLLM:
 if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
 sampling_params.return_context_logits = True
 sampling_params._context_logits_auto_enabled = True
-if sampling_params.logprobs and not sampling_params.return_generation_logits:
+if sampling_params.logprobs is not None and not sampling_params.return_generation_logits:
 sampling_params.return_generation_logits = True
 sampling_params._generation_logits_auto_enabled = True
 
@@ -737,7 +737,7 @@ class BaseLLM:
 f"Example: LLM(..., build_config=BuildConfig(gather_context_logits=True))."
 )
 
-if sampling_params.logprobs and not self.args.gather_generation_logits:
+if sampling_params.logprobs is not None and not self.args.gather_generation_logits:
 raise ValueError(
 f"`sampling_params.logprobs={sampling_params.logprobs}` requires `gather_generation_logits=True` "
 f"to be passed explicitly to the `LLM()` constructor.")
@@ -6,6 +6,7 @@ from typing import List, NamedTuple, Optional, Tuple, Union
 
 import torch
 from pydantic import BaseModel
+from strenum import StrEnum
 
 from tensorrt_llm.bindings import executor as tllme
 from tensorrt_llm.logger import logger
@@ -46,6 +47,18 @@ class LogprobParams(NamedTuple):
 drop_generation_logits: bool = False
 
 
+class LogprobMode(StrEnum):
+RAW = "raw"
+"""
+Return the raw log probabilities, i.e., the log probabilities calculated directly from the model output logits.
+"""
+PROCESSED = "processed"
+"""
+Return the processed log probabilities, i.e., the log probabilities after applying sampling parameters,
+such as temperature, top-k, top-p, etc.
+"""
+
+
 class LogitsProcessor(ABC):
 """Base class for logits processor.
 
@@ -172,7 +185,9 @@ class SamplingParams:
 min_p (float, optional): scale the most likely token to determine the minimum token probability. None means using C++ runtime default 0.0. Defaults to None.
 beam_width_array (List[int], optional): The array of beam width using in Variable-Beam-Width-Search. Defaults to None.
 
-logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None.
+logprobs (int, optional): Number of log probabilities to return per output token. When set to 0, return only the sampled token's log probability.
+When set to K>0, return top-K log probabilities + the sampled token's log probability (last entry) if it's not in the Top-K. Defaults to None.
+logprobs_mode (LogprobMode): The mode of log probabilities to return. Defaults to LogprobMode.RAW.
 prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None.
 return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False.
 return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False.
@@ -219,6 +234,7 @@
 n: int = 1
 best_of: Optional[int] = None
 use_beam_search: bool = False
+logprobs_mode: LogprobMode = LogprobMode.RAW
 
 # Keep the below fields in sync with tllme.SamplingConfig or maintin the mapping table.
 top_k: Optional[int] = None
@@ -321,6 +337,8 @@
 f"under the greedy decoding."
 )
 
+self.logprobs_mode = LogprobMode(self.logprobs_mode)
+
 if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1:
 raise ValueError(
 f"truncate_prompt_tokens must be >= 1, got {self.truncate_prompt_tokens}"
@@ -329,8 +347,11 @@
 if self.guided_decoding is not None:
 self.guided_decoding._validate()
 
-# correct types as users might pass in logprob=True for Top-1 logprobs
-self.logprobs = self.logprobs and int(self.logprobs)
+# correct types as users might pass in logprob=True for Top-0 logprobs and logprobs=False for no logprobs
+if self.logprobs is False:
+self.logprobs = None
+if self.logprobs is True:
+self.logprobs = 0
 self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs)
 
 # NB: Static, because downstream code only holds instances of
@@ -494,7 +515,7 @@
 config_kwargs = {f: getattr(self, f) for f in fields}
 
 if is_pytorch_backend:
-config_kwargs["return_log_probs"] = bool(self.logprobs)
+config_kwargs["return_log_probs"] = self.logprobs is not None
 if self.prompt_logprobs and not self.return_context_logits:
 logger.info(
 "Since prompt_logprobs is requested but return_context_logits is False, "
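For illustration, a minimal end-to-end sketch (not part of the commit) of requesting processed logprobs through the LLM API, modeled on the new tests; the model path is a placeholder:

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder model path

params = SamplingParams(
    max_tokens=8,
    temperature=0.8,
    top_p=0.9,
    logprobs=0,                  # 0 = only the sampled token's logprob; K>0 = top-K (+ sampled token)
    logprobs_mode="processed",   # raw string, coerced to LogprobMode.PROCESSED during validation
)

for output in llm.generate(["The future of AI is"], sampling_params=params):
    for tok_id, lp_dict in zip(output.outputs[0].token_ids, output.outputs[0].logprobs):
        print(tok_id, lp_dict[tok_id].logprob, lp_dict[tok_id].rank)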
@@ -21,7 +21,7 @@ l0_a30:
 - unittest/_torch/modeling -k "modeling_out_of_tree"
 - unittest/_torch/modeling -k "modeling_starcoder2"
 - unittest/_torch/sampler/test_beam_search.py
-- unittest/_torch/sampler/test_return_logits.py
+- unittest/_torch/sampler/test_logits_logprobs.py
 - test_e2e.py::test_openai_completions_with_logit_bias[torch_sampler]
 - test_e2e.py::test_openai_chat_with_logit_bias[torch_sampler]
 - test_e2e.py::test_openai_completions_with_logit_bias[trtllm_sampler]
@@ -284,6 +284,8 @@ test_e2e.py::test_ptp_quickstart_advanced_2gpus_sm120[Nemotron-Super-49B-v1-BF16
 triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
 unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
+unittest/_torch/sampler/test_logits_logprobs.py::test_generate_with_return_logits SKIP (https://nvbugs/5764627)
+unittest/_torch/sampler/test_logits_logprobs.py::test_generate_async_with_return_logits SKIP (https://nvbugs/5764627)
 examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugspro.nvidia.com/bug/5794313)
 examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502)
@@ -125,7 +125,7 @@ def check_generation_logits(beam: CompletionOutput,
 def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams,
 valid_tokens: int | None) -> None:
 """Check if the logprobs have the correct shape"""
-if sampling_params.logprobs:
+if sampling_params.logprobs is not None:
 generated_tokens = valid_tokens if valid_tokens is not None else sampling_params.max_tokens
 assert len(
 beam.logprobs
@@ -345,7 +345,7 @@ class GeneralTestParams:
 prompt_len = len(input_tokens)
 num_generated_tokens = 5
 seq_len = prompt_len + num_generated_tokens
-num_logprobs = 1
+num_logprobs = 0
 seq_slot = 4
 end_id = 99
 batch_size = 2
@@ -541,7 +541,7 @@ def create_default_request(test_params: GeneralTestParams) -> LlmRequest:
 end_id=test_params.end_id,
 sampling_config=SamplingConfig(
 sampling_params._get_sampling_config()),
-return_log_probs=test_params.num_logprobs > 0,
+return_log_probs=test_params.num_logprobs >= 0,
 num_logprobs=test_params.num_logprobs,
 is_streaming=False)
 
@@ -590,7 +590,7 @@ def test_create_beam_history():
 num_generated_tokens = test_params.num_generated_tokens
 seq_slot = test_params.seq_slot
 vocab_size = test_params.vocab_size
-num_logprobs = test_params.num_logprobs
+num_logprobs = test_params.num_logprobs + 1
 cache_indirection = sampler.store.cache_indirection
 original_tokens = sampler.store.original_tokens
 original_logprobs = torch.zeros(
@@ -635,7 +635,11 @@ def test_create_beam_history():
 # set the logprobs in the request:
 token_logprobs = sampler._convert_logprobs_tensor_to_list(
 original_logprob_indices[:beam_width, :num_generated_tokens - 1],
-original_logprobs[:beam_width, :num_generated_tokens - 1])
+original_logprobs[:beam_width, :num_generated_tokens - 1],
+None,
+None,
+None,
+)
 request.py_result.set_log_probs(
 token_logprobs,
 cum_log_probs=torch.zeros_like(
@@ -657,9 +661,10 @@ def test_create_beam_history():
 ) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed."
 
 # set the new log probs and tokens for the beam search sampling
-sampler.store.new_log_probs[
+sampler.store.sampled_log_probs[
 seq_slot, :beam_width] = original_logprobs[:beam_width,
-num_generated_tokens - 1, 0]
+num_generated_tokens - 1,
+0:1]
 sampler.store.new_tokens[
 0,
 seq_slot, :beam_width] = original_logprob_indices[:beam_width,
tests/unittest/_torch/sampler/test_logits_logprobs.py (new file, 589 lines)
@@ -0,0 +1,589 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from utils.llm_data import llm_models_root
|
||||
from utils.util import force_ampere
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm._torch.pyexecutor.sampling_utils import top_k_top_p_sampling_batch
|
||||
from tensorrt_llm._torch.pyexecutor.sampling_utils_flashinfer import _StrategyImpls
|
||||
from tensorrt_llm.llmapi.llm_utils import KvCacheConfig
|
||||
|
||||
prompts = ["A B C"]
|
||||
global_kvcache_config = KvCacheConfig(
|
||||
max_tokens=10000,
|
||||
enable_block_reuse=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def gather_generation_logits_fixture(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def gather_context_logits_fixture(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def disable_overlap_scheduler_fixture(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"])
|
||||
def sampler_type_fixture(request) -> str:
|
||||
return request.param
|
||||
|
||||
|
||||
class CacheSalter:
|
||||
_salt = 0
|
||||
|
||||
@classmethod
|
||||
def get_salt_unique(cls) -> str:
|
||||
cls._salt += 1
|
||||
return str(cls._salt)
|
||||
|
||||
@classmethod
|
||||
def get_salt_shared(cls) -> str:
|
||||
return str(0)
|
||||
|
||||
@classmethod
|
||||
def get_salt(cls, reuse_cache: bool) -> str:
|
||||
if reuse_cache:
|
||||
salt = cls.get_salt_shared()
|
||||
else:
|
||||
salt = cls.get_salt_unique()
|
||||
return salt
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm(
|
||||
gather_context_logits_fixture: bool,
|
||||
gather_generation_logits_fixture: bool,
|
||||
sampler_type_fixture: str,
|
||||
disable_overlap_scheduler_fixture: bool,
|
||||
):
|
||||
gather_generation_logits = gather_generation_logits_fixture
|
||||
sampler_type = sampler_type_fixture
|
||||
disable_overlap_scheduler = disable_overlap_scheduler_fixture
|
||||
|
||||
llm = LLM(
|
||||
model=os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"),
|
||||
kv_cache_config=global_kvcache_config,
|
||||
gather_generation_logits=gather_generation_logits,
|
||||
max_batch_size=128, # reduce buffer sizes, specially for generation logits
|
||||
sampler_type=sampler_type,
|
||||
disable_overlap_scheduler=disable_overlap_scheduler,
|
||||
)
|
||||
|
||||
# FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178.
|
||||
# Remove patch below once fixed.
|
||||
old_exit = LLM.__exit__
|
||||
|
||||
def _exit_with_xfail_on_timeout(self, exc_type, exc_value, traceback) -> bool:
|
||||
import _pytest.outcomes
|
||||
|
||||
try:
|
||||
return old_exit(self, exc_type, exc_value, traceback)
|
||||
except _pytest.outcomes.Failed as e:
|
||||
if e.msg and "pytest-timeout" in e.msg.lower():
|
||||
pytest.xfail("Known LLM shutdown issue (https://nvbugs/5577178).")
|
||||
else:
|
||||
raise
|
||||
|
||||
with pytest.MonkeyPatch.context() as patch:
|
||||
patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout)
|
||||
|
||||
with llm:
|
||||
yield llm
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def simple_llm(request) -> LLM:
|
||||
disable_flashinfer_sampling = request.param
|
||||
llm = LLM(
|
||||
model=os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"),
|
||||
max_batch_size=8,
|
||||
disable_flashinfer_sampling=disable_flashinfer_sampling,
|
||||
)
|
||||
return llm
|
||||
|
||||
|
||||
@force_ampere # Save H100 resource
|
||||
@pytest.mark.parametrize("reuse_cache", [False, True])
|
||||
@pytest.mark.parametrize("return_log_probs", [False, True])
|
||||
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
|
||||
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
|
||||
@pytest.mark.timeout(120, method="signal")
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
def test_generate_with_return_logits(
|
||||
llm,
|
||||
gather_context_logits_fixture: bool,
|
||||
gather_generation_logits_fixture: bool,
|
||||
reuse_cache: bool,
|
||||
return_log_probs: bool,
|
||||
):
|
||||
gather_context_logits = gather_context_logits_fixture
|
||||
gather_generation_logits = gather_generation_logits_fixture
|
||||
|
||||
if not (gather_context_logits or gather_generation_logits or return_log_probs): # prune space
|
||||
pytest.skip("Nothing to test")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=8,
|
||||
return_context_logits=gather_context_logits,
|
||||
return_generation_logits=gather_generation_logits,
|
||||
logprobs=return_log_probs,
|
||||
)
|
||||
|
||||
for output in llm.generate(
|
||||
prompts,
|
||||
sampling_params=sampling_params,
|
||||
cache_salt=[CacheSalter.get_salt(reuse_cache) for _ in prompts],
|
||||
):
|
||||
if gather_context_logits:
|
||||
assert output.context_logits is not None
|
||||
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
|
||||
expected_len = len(prompts[0].split()) + 1
|
||||
try:
|
||||
assert expected_len == output.context_logits.shape[0]
|
||||
except AssertionError:
|
||||
# FIXME: Remove this once the bug has been fixed
|
||||
if gather_context_logits and reuse_cache:
|
||||
pytest.xfail("Known bug: https://nvbugs/5577178")
|
||||
raise
|
||||
else:
|
||||
assert output.context_logits is None
|
||||
|
||||
for sequence in output.outputs:
|
||||
assert sequence.length == sampling_params.max_tokens
|
||||
|
||||
if gather_generation_logits:
|
||||
gen_logits = sequence.generation_logits
|
||||
assert gen_logits is not None
|
||||
assert gen_logits.ndim == 2
|
||||
assert gen_logits.shape[0] == sampling_params.max_tokens
|
||||
assert torch.argmax(gen_logits, dim=1).tolist() == sequence.token_ids
|
||||
else:
|
||||
assert sequence.generation_logits is None
|
||||
|
||||
if return_log_probs:
|
||||
assert len(sequence.logprobs) == sampling_params.max_tokens
|
||||
else:
|
||||
assert len(sequence.logprobs) == 0
|
||||
|
||||
|
||||
@force_ampere # Save H100 resource
|
||||
@pytest.mark.parametrize("reuse_cache", [False, True])
|
||||
@pytest.mark.parametrize("return_log_probs", [False, True])
|
||||
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
|
||||
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
|
||||
@pytest.mark.timeout(120, method="signal")
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
def test_generate_async_with_return_logits(
|
||||
llm,
|
||||
gather_context_logits_fixture: bool,
|
||||
gather_generation_logits_fixture: bool,
|
||||
reuse_cache: bool,
|
||||
return_log_probs: bool,
|
||||
):
|
||||
gather_context_logits = gather_context_logits_fixture
|
||||
gather_generation_logits = gather_generation_logits_fixture
|
||||
|
||||
if not (gather_context_logits or gather_generation_logits or return_log_probs): # prune space
|
||||
pytest.skip("Nothing to test")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=8,
|
||||
return_context_logits=gather_context_logits,
|
||||
return_generation_logits=gather_generation_logits,
|
||||
logprobs=return_log_probs,
|
||||
)
|
||||
|
||||
for idx, output in enumerate(
|
||||
llm.generate_async(
|
||||
prompts[0],
|
||||
sampling_params=sampling_params,
|
||||
streaming=True,
|
||||
cache_salt=CacheSalter.get_salt(reuse_cache),
|
||||
)
|
||||
):
|
||||
if gather_context_logits:
|
||||
assert output.context_logits is not None
|
||||
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
|
||||
expected_len = len(prompts[0].split()) + 1
|
||||
try:
|
||||
assert expected_len == output.context_logits.shape[0]
|
||||
except AssertionError:
|
||||
# FIXME: Remove this once the bug has been fixed
|
||||
if gather_context_logits and reuse_cache:
|
||||
pytest.xfail("Known bug: https://nvbugs/5577178")
|
||||
raise
|
||||
else:
|
||||
assert output.context_logits is None
|
||||
|
||||
for sequence in output.outputs:
|
||||
assert sequence.length == idx + 1
|
||||
|
||||
if gather_generation_logits:
|
||||
gen_logits = sequence.generation_logits
|
||||
assert gen_logits is not None
|
||||
assert gen_logits.ndim == 2
|
||||
assert gen_logits.shape[0] == 1
|
||||
try:
|
||||
assert torch.argmax(gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1]
|
||||
except AssertionError:
|
||||
# FIXME: Remove xfail once the bug is fixed
|
||||
pytest.xfail("Known bug: https://nvbugs/5573238")
|
||||
else:
|
||||
assert sequence.generation_logits is None
|
||||
|
||||
if return_log_probs:
|
||||
assert len(sequence.logprobs) == idx + 1
|
||||
else:
|
||||
assert len(sequence.logprobs) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logprobs_k", [0, 1, 3], ids=["top_0", "top_1", "top_3"])
|
||||
@pytest.mark.parametrize("logprobs_mode", ["raw", "processed"])
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
def test_sampled_token_always_in_logprobs(logprobs_k: int, logprobs_mode: str, simple_llm: LLM):
|
||||
"""Two scenarios:
|
||||
- logprobs=0: Returns only sampled token (1 element)
|
||||
- logprobs=K (K>0): Returns top-K tokens + sampled token if not in top-K (up to K+1 elements)
|
||||
"""
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=8,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
logprobs=logprobs_k,
|
||||
logprobs_mode=logprobs_mode,
|
||||
)
|
||||
|
||||
for output in simple_llm.generate(["The future of AI is"], sampling_params=sampling_params):
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Generated text: {output.outputs[0].text!r}")
|
||||
print(f"Generated token IDs: {output.outputs[0].token_ids}")
|
||||
|
||||
logprobs = output.outputs[0].logprobs
|
||||
token_ids = output.outputs[0].token_ids
|
||||
|
||||
assert len(logprobs) == sampling_params.max_tokens, (
|
||||
f"Expected {sampling_params.max_tokens} logprob entries, got {len(logprobs)}"
|
||||
)
|
||||
|
||||
for token_idx, (sampled_token_id, token_logprobs) in enumerate(zip(token_ids, logprobs)):
|
||||
print(
|
||||
f"\n Token {token_idx}: "
|
||||
f"ID={sampled_token_id}, "
|
||||
f"Text={simple_llm.tokenizer.decode([sampled_token_id])!r}"
|
||||
)
|
||||
|
||||
assert sampled_token_id in token_logprobs, (
|
||||
f"Token {token_idx}: Sampled token ID {sampled_token_id} not in logprobs dict: {token_logprobs.keys()}"
|
||||
)
|
||||
|
||||
if logprobs_k == 0:
|
||||
assert len(token_logprobs) == 1, (
|
||||
f"Token {token_idx}: Expected 1 logprob (sampled only), got {len(token_logprobs)}"
|
||||
)
|
||||
else:
|
||||
assert len(token_logprobs) <= logprobs_k + 1, (
|
||||
f"Token {token_idx}: Expected at most {logprobs_k + 1} logprobs, got {len(token_logprobs)}"
|
||||
)
|
||||
assert len(token_logprobs) >= 1
|
||||
|
||||
sorted_tokens_by_prob = sorted(
|
||||
token_logprobs.items(), key=lambda x: x[1].logprob, reverse=True
|
||||
)
|
||||
|
||||
if logprobs_k > 0:
|
||||
sampled_token_rank = token_logprobs[sampled_token_id].rank
|
||||
sampled_in_topk = sampled_token_rank <= logprobs_k
|
||||
|
||||
if not sampled_in_topk:
|
||||
assert sorted_tokens_by_prob[-1][0] == sampled_token_id, (
|
||||
f"Token {token_idx}: Sampled token (ID={sampled_token_id}, rank={sampled_token_rank}) "
|
||||
f"not in top-{logprobs_k}, should be last in sorted list, "
|
||||
f"but last token is ID={sorted_tokens_by_prob[-1][0]}"
|
||||
)
|
||||
|
||||
for rank_idx, (token_id, logprob_obj) in enumerate(sorted_tokens_by_prob, start=1):
|
||||
token_text = simple_llm.tokenizer.decode([token_id])
|
||||
is_sampled = "← SAMPLED" if token_id == sampled_token_id else ""
|
||||
print(
|
||||
f" • Token {token_id:5d} ({token_text:15s}): "
|
||||
f"logprob={logprob_obj.logprob:8.4f}, "
|
||||
f"rank={logprob_obj.rank} {is_sampled}"
|
||||
)
|
||||
|
||||
if logprobs_k > 0 and sampled_in_topk:
|
||||
assert logprob_obj.rank == rank_idx, (
|
||||
f"Token {token_idx}: Token {token_id} rank mismatch. "
|
||||
f"Expected rank {rank_idx} (by sorted position), got {logprob_obj.rank}"
|
||||
)
|
||||
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logprobs_k", [0, 2], ids=["top_0", "top_2"])
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
def test_logprobs_with_grouped_samplings_strategies(logprobs_k: int, simple_llm: LLM):
|
||||
"""Test logprobs when requests are reordered by sampling strategy grouping"""
|
||||
|
||||
test_prompts = [
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"Hello, my name is",
|
||||
"Hello, my name is",
|
||||
"Write a short story about a cat",
|
||||
]
|
||||
|
||||
# Causes reordering: [0,1,2,3,4] → [0,2,3,1,4]
|
||||
sampling_params_list = [
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
top_k=50,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
),
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
top_p=0.9,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
),
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
top_k=50,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
),
|
||||
SamplingParams(
|
||||
max_tokens=6, temperature=0.8, top_k=50, logprobs=None, return_generation_logits=True
|
||||
),
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
top_p=0.9,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = list(simple_llm.generate(test_prompts, sampling_params=sampling_params_list))
|
||||
|
||||
for req_idx, output in enumerate(outputs):
|
||||
generation_logits = output.outputs[0].generation_logits.to(device="cuda")
|
||||
token_ids = output.outputs[0].token_ids
|
||||
logprobs = output.outputs[0].logprobs
|
||||
if sampling_params_list[req_idx].logprobs is None:
|
||||
assert len(logprobs) == 0
|
||||
continue
|
||||
|
||||
assert generation_logits is not None
|
||||
assert len(logprobs) == len(token_ids), "Logprobs length mismatch"
|
||||
|
||||
# generation_logits might be shorter than token_ids
|
||||
num_logits = len(generation_logits)
|
||||
|
||||
for token_idx, (sampled_token_id, token_logprobs_dict) in enumerate(
|
||||
zip(token_ids[:num_logits], logprobs[:num_logits])
|
||||
):
|
||||
returned_logprob = token_logprobs_dict[sampled_token_id].logprob
|
||||
|
||||
logits_for_token = generation_logits[token_idx]
|
||||
expected_logprobs = torch.nn.functional.log_softmax(logits_for_token, dim=-1).to(
|
||||
device="cpu"
|
||||
)
|
||||
expected_logprob = expected_logprobs[sampled_token_id].item()
|
||||
print(
|
||||
f"Req {req_idx}, Token {token_idx}: returned={returned_logprob:.6f}, expected={expected_logprob:.6f}"
|
||||
)
|
||||
torch.testing.assert_close(returned_logprob, expected_logprob)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logprobs_k", [0, 2], ids=["top_0", "top_2"])
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
def test_processed_logprobs_e2e(logprobs_k: int, simple_llm: LLM):
|
||||
"""Test logprobs when requests are reordered by sampling strategy grouping"""
|
||||
test_prompts = [
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"Hello, my name is",
|
||||
"Write a short story about a cat",
|
||||
"Hello, my name is",
|
||||
"Write a short story about a cat",
|
||||
]
|
||||
|
||||
sampling_params_list = [
|
||||
# greedy decoding
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
logprobs_mode="processed",
|
||||
),
|
||||
# temperature sampling
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
logprobs_mode="processed",
|
||||
),
|
||||
# top-p sampling
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
top_p=0.9,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
logprobs_mode="processed",
|
||||
),
|
||||
# top-k sampling
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
top_k=50,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
logprobs_mode="processed",
|
||||
),
|
||||
# top-p sampling 2
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
top_p=0.9,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
logprobs_mode="processed",
|
||||
),
|
||||
# top-p and top-k sampling
|
||||
SamplingParams(
|
||||
max_tokens=6,
|
||||
temperature=0.8,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
logprobs=logprobs_k,
|
||||
return_generation_logits=True,
|
||||
logprobs_mode="processed",
|
||||
),
|
||||
]
|
||||
|
||||
outputs = list(simple_llm.generate(test_prompts, sampling_params=sampling_params_list))
|
||||
|
||||
for req_idx, output in enumerate(outputs):
|
||||
generation_logits = output.outputs[0].generation_logits.to(device="cuda")
|
||||
token_ids = output.outputs[0].token_ids
|
||||
logprobs = output.outputs[0].logprobs
|
||||
|
||||
assert generation_logits is not None
|
||||
assert len(logprobs) == len(token_ids), "Logprobs length mismatch"
|
||||
|
||||
# generation_logits might be shorter than token_ids
|
||||
num_logits = len(generation_logits)
|
||||
|
||||
for token_idx, token_logprobs_dict in enumerate(logprobs[:num_logits]):
|
||||
assert token_ids[token_idx] in token_logprobs_dict, "Sampled token not in logprobs"
|
||||
|
||||
logits_for_token = generation_logits[token_idx : token_idx + 1]
|
||||
topk = sampling_params_list[req_idx].top_k
|
||||
topp = sampling_params_list[req_idx].top_p
|
||||
temperature = sampling_params_list[req_idx].temperature
|
||||
if sampling_params_list[req_idx]._greedy_decoding:
|
||||
probs = torch.zeros_like(logits_for_token)
|
||||
probs[0, token_ids[token_idx]] = 1.0
|
||||
else:
|
||||
topk = topk if topk is not None else logits_for_token.shape[-1]
|
||||
topp = topp if topp is not None else 1.0
|
||||
temperature = temperature if temperature is not None else 1.0
|
||||
|
||||
# perform top-k / top-p masking
|
||||
if simple_llm.args.disable_flashinfer_sampling:
|
||||
_, probs = top_k_top_p_sampling_batch(
|
||||
logits_for_token, top_k=topk, top_p=topp, temperature=temperature
|
||||
)
|
||||
else:
|
||||
_, probs = _StrategyImpls.StrategyImplWithProbs._sample_with_probs(
|
||||
logits_for_token,
|
||||
group_logit_indices=None,
|
||||
top_k=torch.tensor([topk], dtype=torch.int32, device="cuda"),
|
||||
top_p=torch.tensor([topp], dtype=torch.float32, device="cuda"),
|
||||
temperature=torch.tensor([temperature], dtype=torch.float32, device="cuda"),
|
||||
generator=None,
|
||||
)
|
||||
|
||||
if temperature != 0:
|
||||
logits_for_token /= temperature
|
||||
adjusted_logits_for_token = torch.where(probs != 0, logits_for_token, float("-inf"))[0]
|
||||
expected_logprobs = torch.nn.functional.log_softmax(
|
||||
adjusted_logits_for_token, dim=-1
|
||||
).to(device="cpu")
|
||||
for logprob_token, logprob_values in token_logprobs_dict.items():
|
||||
expected_logprob = expected_logprobs[logprob_token].item()
|
||||
returned_logprob = logprob_values.logprob
|
||||
print(
|
||||
f"Req {req_idx}, Token {token_idx}: "
|
||||
f"returned={returned_logprob:.6f}, expected={expected_logprob:.6f}"
|
||||
)
|
||||
torch.testing.assert_close(returned_logprob, expected_logprob)
|
||||
|
||||
|
||||
@force_ampere
|
||||
@pytest.mark.gpu2
|
||||
def test_logprobs_match_hf_tp2():
|
||||
model_path = os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0")
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
|
||||
prompts = ["The future of the AI is"]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=10,
|
||||
temperature=1.0,
|
||||
logprobs=0,
|
||||
)
|
||||
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to(
|
||||
"cuda"
|
||||
)
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
output = list(llm.generate(prompts, sampling_params=sampling_params))[0]
|
||||
|
||||
trtllm_token_ids = output.outputs[0].token_ids
|
||||
trtllm_logprobs = torch.tensor(
|
||||
[list(lp.values())[0].logprob for lp in output.outputs[0].logprobs]
|
||||
)
|
||||
|
||||
base_ids = hf_tokenizer.encode(prompts[0], return_tensors="pt").to("cuda")
|
||||
hf_logprobs = []
|
||||
|
||||
for i, token_id in enumerate(trtllm_token_ids):
|
||||
if i > 0:
|
||||
prev_tokens = torch.tensor([trtllm_token_ids[:i]], device="cuda")
|
||||
input_ids = torch.cat([base_ids, prev_tokens], dim=1)
|
||||
else:
|
||||
input_ids = base_ids
|
||||
with torch.no_grad():
|
||||
logits = hf_model(input_ids).logits[0, -1, :]
|
||||
hf_logprobs.append(torch.log_softmax(logits, dim=-1)[token_id].item())
|
||||
|
||||
hf_logprobs = torch.tensor(hf_logprobs)
|
||||
|
||||
print(f"\nTensorRT-LLM logprobs: {trtllm_logprobs}")
|
||||
print(f"HuggingFace logprobs: {hf_logprobs}")
|
||||
print(f"Diff: {(trtllm_logprobs - hf_logprobs).abs()}")
|
||||
|
||||
torch.testing.assert_close(trtllm_logprobs, hf_logprobs, atol=0.15, rtol=0)
|
||||
@@ -1,239 +0,0 @@
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from utils.llm_data import llm_models_root
|
||||
from utils.util import force_ampere
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.llmapi.llm_utils import KvCacheConfig
|
||||
|
||||
prompts = ["A B C"]
|
||||
global_kvcache_config = KvCacheConfig(
|
||||
max_tokens=10000,
|
||||
enable_block_reuse=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def gather_generation_logits_fixture(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def gather_context_logits_fixture(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def disable_overlap_scheduler_fixture(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"])
|
||||
def sampler_type_fixture(request) -> str:
|
||||
return request.param
|
||||
|
||||
|
||||
class CacheSalter:
|
||||
|
||||
_salt = 0
|
||||
|
||||
@classmethod
|
||||
def get_salt_unique(cls) -> str:
|
||||
cls._salt += 1
|
||||
return str(cls._salt)
|
||||
|
||||
@classmethod
|
||||
def get_salt_shared(cls) -> str:
|
||||
return str(0)
|
||||
|
||||
@classmethod
|
||||
def get_salt(cls, reuse_cache: bool) -> str:
|
||||
if reuse_cache:
|
||||
salt = cls.get_salt_shared()
|
||||
else:
|
||||
salt = cls.get_salt_unique()
|
||||
return salt
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm(
|
||||
gather_context_logits_fixture: bool,
|
||||
gather_generation_logits_fixture: bool,
|
||||
sampler_type_fixture: str,
|
||||
disable_overlap_scheduler_fixture: bool,
|
||||
):
|
||||
gather_generation_logits = gather_generation_logits_fixture
|
||||
sampler_type = sampler_type_fixture
|
||||
disable_overlap_scheduler = disable_overlap_scheduler_fixture
|
||||
|
||||
llm = LLM(
|
||||
model=os.path.join(llm_models_root(), "llama-models-v2",
|
||||
"TinyLlama-1.1B-Chat-v1.0"),
|
||||
kv_cache_config=global_kvcache_config,
|
||||
gather_generation_logits=gather_generation_logits,
|
||||
max_batch_size=
|
||||
128, # reduce buffer sizes, specially for generation logits
|
||||
sampler_type=sampler_type,
|
||||
disable_overlap_scheduler=disable_overlap_scheduler,
|
||||
)
|
||||
|
||||
# FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178.
|
||||
# Remove patch below once fixed.
|
||||
old_exit = LLM.__exit__
|
||||
|
||||
def _exit_with_xfail_on_timeout(self, exc_type, exc_value,
|
||||
traceback) -> bool:
|
||||
import _pytest.outcomes
|
||||
try:
|
||||
return old_exit(self, exc_type, exc_value, traceback)
|
||||
except _pytest.outcomes.Failed as e:
|
||||
if e.msg and "pytest-timeout" in e.msg.lower():
|
||||
pytest.xfail(
|
||||
"Known LLM shutdown issue (https://nvbugs/5577178).")
|
||||
else:
|
||||
raise
|
||||
|
||||
with pytest.MonkeyPatch.context() as patch:
|
||||
patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout)
|
||||
|
||||
with llm:
|
||||
yield llm
|
||||
|
||||
|
||||
@force_ampere # Save H100 resource
|
||||
@pytest.mark.parametrize("reuse_cache", [False, True])
|
||||
@pytest.mark.parametrize("return_log_probs", [False, True])
|
||||
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
|
||||
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
|
||||
@pytest.mark.timeout(120, method="signal")
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
def test_generate_with_return_logits(
|
||||
llm,
|
||||
gather_context_logits_fixture: bool,
|
||||
gather_generation_logits_fixture: bool,
|
||||
reuse_cache: bool,
|
||||
return_log_probs: bool,
|
||||
):
|
||||
gather_context_logits = gather_context_logits_fixture
|
||||
gather_generation_logits = gather_generation_logits_fixture
|
||||
|
||||
if not (gather_context_logits or gather_generation_logits
|
||||
or return_log_probs): # prune space
|
||||
pytest.skip("Nothing to test")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=8,
|
||||
return_context_logits=gather_context_logits,
|
||||
return_generation_logits=gather_generation_logits,
|
||||
logprobs=return_log_probs,
|
||||
)
|
||||
|
||||
for output in llm.generate(
|
||||
prompts,
|
||||
sampling_params=sampling_params,
|
||||
cache_salt=[CacheSalter.get_salt(reuse_cache) for _ in prompts],
|
||||
):
|
||||
if gather_context_logits:
|
||||
assert output.context_logits is not None
|
||||
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
|
||||
expected_len = len(prompts[0].split()) + 1
|
||||
try:
|
||||
assert expected_len == output.context_logits.shape[0]
|
||||
except AssertionError:
|
||||
# FIXME: Remove this once the bug has been fixed
|
||||
if gather_context_logits and reuse_cache:
|
||||
pytest.xfail("Known bug: https://nvbugs/5577178")
|
||||
raise
|
||||
else:
|
||||
assert output.context_logits is None
|
||||
|
||||
for sequence in output.outputs:
|
||||
assert sequence.length == sampling_params.max_tokens
|
||||
|
||||
if gather_generation_logits:
|
||||
gen_logits = sequence.generation_logits
|
||||
assert gen_logits is not None
|
||||
assert gen_logits.ndim == 2
|
||||
assert gen_logits.shape[0] == sampling_params.max_tokens
|
||||
assert torch.argmax(gen_logits,
|
||||
dim=1).tolist() == sequence.token_ids
|
||||
else:
|
||||
assert sequence.generation_logits is None
|
||||
|
||||
if return_log_probs:
|
||||
assert len(sequence.logprobs) == sampling_params.max_tokens
|
||||
else:
|
||||
assert len(sequence.logprobs) == 0
|
||||
|
||||
|
||||
@force_ampere # Save H100 resource
|
||||
@pytest.mark.parametrize("reuse_cache", [False, True])
|
||||
@pytest.mark.parametrize("return_log_probs", [False, True])
|
||||
# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178
|
||||
# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134
|
||||
@pytest.mark.timeout(120, method="signal")
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
def test_generate_async_with_return_logits(
|
||||
llm,
|
||||
gather_context_logits_fixture: bool,
|
||||
gather_generation_logits_fixture: bool,
|
||||
reuse_cache: bool,
|
||||
return_log_probs: bool,
|
||||
):
|
||||
gather_context_logits = gather_context_logits_fixture
|
||||
gather_generation_logits = gather_generation_logits_fixture
|
||||
|
||||
if not (gather_context_logits or gather_generation_logits
|
||||
or return_log_probs): # prune space
|
||||
pytest.skip("Nothing to test")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=8,
|
||||
return_context_logits=gather_context_logits,
|
||||
return_generation_logits=gather_generation_logits,
|
||||
logprobs=return_log_probs)
|
||||
|
||||
for idx, output in enumerate(
|
||||
llm.generate_async(
|
||||
prompts[0],
|
||||
sampling_params=sampling_params,
|
||||
streaming=True,
|
||||
cache_salt=CacheSalter.get_salt(reuse_cache),
|
||||
)):
|
||||
if gather_context_logits:
|
||||
assert output.context_logits is not None
|
||||
# NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315]
|
||||
expected_len = len(prompts[0].split()) + 1
|
||||
try:
|
||||
assert expected_len == output.context_logits.shape[0]
|
||||
except AssertionError:
|
||||
# FIXME: Remove this once the bug has been fixed
|
||||
if gather_context_logits and reuse_cache:
|
||||
pytest.xfail("Known bug: https://nvbugs/5577178")
|
||||
raise
|
||||
else:
|
||||
assert output.context_logits is None
|
||||
|
||||
for sequence in output.outputs:
|
||||
assert sequence.length == idx + 1
|
||||
|
||||
if gather_generation_logits:
|
||||
gen_logits = sequence.generation_logits
|
||||
assert gen_logits is not None
|
||||
assert gen_logits.ndim == 2
|
||||
assert gen_logits.shape[0] == 1
|
||||
try:
|
||||
assert torch.argmax(
|
||||
gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1]
|
||||
except AssertionError:
|
||||
# FIXME: Remove xfail once the bug is fixed
|
||||
pytest.xfail("Known bug: https://nvbugs/5573238")
|
||||
else:
|
||||
assert sequence.generation_logits is None
|
||||
|
||||
if return_log_probs:
|
||||
assert len(sequence.logprobs) == idx + 1
|
||||
else:
|
||||
assert len(sequence.logprobs) == 0
|
||||
@@ -86,7 +86,7 @@ class TestStrategySelection:
 is_context_init_state: bool # Torch sampler accesses this, but it does not affect this test
 
 def get_beam_width_by_iter(
-self, for_next_iteration: bool
+self, for_next_iteration: bool = False
 ) -> int: # Torch sampler accesses this, but it does not affect this test
 return self.sampling_config.beam_width
 
@@ -445,12 +445,22 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool)
 def py_return_context_logits(self) -> bool:
 return self._return_context_logits
 
+def get_beam_width_by_iter(
+self, for_next_iteration: bool = False
+) -> int: # Torch sampler accesses this, but it does not affect this test
+return self.sampling_config.beam_width
+
 class GenRequestMock:
 def __init__(self, draft_len: int):
 self.is_context_init_state = False
 self.py_draft_tokens = torch.empty(draft_len, dtype=torch.int32, device=device)
 self.sampling_config = SamplingConfig(beam_width=1)
 
+def get_beam_width_by_iter(
+self, for_next_iteration: bool = False
+) -> int: # Torch sampler accesses this, but it does not affect this test
+return self.sampling_config.beam_width
+
 class ScheduledRequestsMock:
 @property
 def context_requests(self) -> list[LlmRequest]:
@@ -31,6 +31,7 @@ from tensorrt_llm.llmapi import (CalibConfig, CompletionOutput,
 from tensorrt_llm.llmapi.llm_args import SamplerType
 from tensorrt_llm.llmapi.llm_utils import LlmArgs
 from tensorrt_llm.logger import Singleton
+from tensorrt_llm.sampling_params import LogprobMode
 
 
 def repr_annotation(field_type: type) -> str:
@@ -15,5 +15,8 @@ methods:
 prompt_ignore_length:
 annotation: Optional[int]
 default: null
+logprobs_mode:
+annotation: LogprobMode
+default: LogprobMode.RAW
 return_annotation: None
 properties: {}