[https://nvbugs/5513423][fix] Correctly respect min_tokens in PyTorch Workflow (#7808)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Co-authored-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com>
Stefan Niebler 2025-09-22 07:15:18 +02:00 committed by GitHub
parent 9dc7316b7f
commit 8aead224fb
3 changed files with 56 additions and 8 deletions
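
For context, min_tokens on SamplingParams asks the sampler to suppress the end-of-sequence token until at least that many tokens have been generated. A minimal usage sketch of the behaviour this commit fixes (the checkpoint and prompt below are illustrative placeholders, not taken from this change):

from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import SamplingParams

# Placeholder checkpoint; any model supported by the PyTorch backend works.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# With min_tokens set, the end token must not be sampled before 32 tokens
# have been produced, even if the model would otherwise stop earlier.
sampling_params = SamplingParams(max_tokens=64, min_tokens=32, temperature=1)
outputs = llm.generate(["Hello, my name is"], sampling_params)
assert len(outputs[0].outputs[0].token_ids) >= 32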

@@ -327,6 +327,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         self.py_prompt_len = self.prompt_len
         self.py_orig_prompt_len = self.orig_prompt_len
         self.py_max_new_tokens = self.max_new_tokens
+        self.py_min_length = self.sampling_config.min_length
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0

@@ -657,6 +657,36 @@ class TorchSampler(Sampler):
         return logits
 
+    @staticmethod
+    @torch.inference_mode()
+    def _apply_min_length_penalty(logits: torch.Tensor,
+                                  requests: list[LlmRequest],
+                                  num_steps: list[int]) -> torch.Tensor:
+        """Inplace apply min_length_penalty to logits.
+
+        Args:
+            logits: The logits to apply min length penalty to
+            requests: The requests to apply min length penalty to
+            num_steps: The number of steps per request
+
+        Returns:
+            The logits with min length penalty applied
+        """
+        if any(r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]
+               for r in requests):
+            current_offset = 0
+            for index, r in enumerate(requests):
+                if r.py_min_length:
+                    for step in range(num_steps[index]):
+                        if r.max_beam_num_tokens + step < r.py_min_length[0]:
+                            logits[current_offset + step,
+                                   r.py_end_id] = float('-inf')
+                        else:
+                            # early exit
+                            break
+                current_offset += num_steps[index]
+        return logits
+
     def _process_requests(self,
                           scheduled_requests: ScheduledRequests,
                           model_outputs: dict[str, torch.Tensor],
@@ -686,6 +716,8 @@ class TorchSampler(Sampler):
         requests = scheduled_requests.all_requests()
         num_steps = [1 + get_draft_token_length(req) for req in requests]
+        raw_logits = self._apply_min_length_penalty(raw_logits, requests,
+                                                    num_steps)
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
         fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None
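
The penalty itself is just a mask: for every decoding step that would still leave a request short of min_length, the end token's logit is driven to negative infinity so it can never be sampled or win an argmax. A self-contained toy sketch of that idea (illustrative names, not the TorchSampler code above):

import torch

vocab_size, end_id, min_length = 8, 2, 5
generated_so_far = 3  # tokens this request has produced so far
logits = torch.randn(1, vocab_size)

# Below min_length the end token is masked out for this step.
if generated_so_far < min_length:
    logits[0, end_id] = float('-inf')

probs = torch.softmax(logits, dim=-1)
assert probs[0, end_id] == 0.0  # exp(-inf) == 0, so EOS cannot be chosen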

@@ -6,7 +6,7 @@ import pytest
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import GenerationExecutorWorker
 from tensorrt_llm.llmapi import KvCacheConfig
-from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
+from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
 from tensorrt_llm.metrics import MetricNames
 from tensorrt_llm.sampling_params import SamplingParams
@@ -861,15 +861,30 @@ def test_llm_with_proxy_error():
 @pytest.mark.part0
-@pytest.mark.xfail(reason="https://nvbugs/5513423")
-def test_min_tokens():
+@pytest.mark.parametrize("use_speculative", [True, False])
+def test_min_tokens(use_speculative: bool):
     """Check min_tokens is respected."""
-    llm = LLM(model=llama_model_path,
-              kv_cache_config=global_kvcache_config,
-              enable_mixed_sampler=True,
-              max_seq_len=20000)
+    llm_common_config = dict(
+        model=llama_model_path,
+        max_batch_size=2,
+        kv_cache_config=global_kvcache_config,
+        max_num_tokens=2048,
+        enable_mixed_sampler=True,
+    )
-    output_len = 5000
+    if use_speculative:
+        spec_config = NGramDecodingConfig(
+            max_draft_len=4,
+            max_matching_ngram_size=2,
+            is_keep_all=True,
+            is_use_oldest=True,
+            is_public_pool=True,
+        )
+        llm = LLM(**llm_common_config, speculative_config=spec_config)
+    else:
+        llm = LLM(**llm_common_config)
+    output_len = 2000
     sampling_params = SamplingParams(max_tokens=output_len,
                                      min_tokens=output_len,
                                      temperature=1)
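
Downstream of the snippet above, the test presumably checks that every completion really reaches output_len tokens. A sketch of that kind of check, reusing the names from the diff (the prompt and the exact assertions in the repository are assumptions here):

prompts = ["The capital of France is"]  # illustrative prompt
for request_output in llm.generate(prompts, sampling_params):
    for completion in request_output.outputs:
        # min_tokens == max_tokens == output_len, so each completion should
        # contain exactly output_len generated tokens.
        assert len(completion.token_ids) == output_len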