Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5513423][fix] Correctly respect min_tokens in PyTorch Workflow (#7808)
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Co-authored-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
parent 9dc7316b7f
commit 8aead224fb
@@ -327,6 +327,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
         self.py_prompt_len = self.prompt_len
         self.py_orig_prompt_len = self.orig_prompt_len
         self.py_max_new_tokens = self.max_new_tokens
+        self.py_min_length = self.sampling_config.min_length
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0
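Note: `min_tokens` is the knob users set through the LLM API; the new `py_min_length` attribute simply caches the request's `sampling_config.min_length` on the Python side so the PyTorch sampler can consult it directly. A minimal caller-side sketch, using only the `SamplingParams` arguments that appear in the test further down (the mapping onto `py_min_length` is inferred from this diff rather than documented):

```python
from tensorrt_llm.sampling_params import SamplingParams

# Request at least 32 generated tokens; the end token must not be sampled earlier.
params = SamplingParams(max_tokens=64, min_tokens=32, temperature=1)
# Per the assignment added above, the corresponding LlmRequest is expected to
# expose this as py_min_length (read as py_min_length[0] in the sampler below).
```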
@@ -657,6 +657,36 @@ class TorchSampler(Sampler):

         return logits

+    @staticmethod
+    @torch.inference_mode()
+    def _apply_min_length_penalty(logits: torch.Tensor,
+                                  requests: list[LlmRequest],
+                                  num_steps: list[int]) -> torch.Tensor:
+        """Inplace apply min_length_penalty to logits.
+
+        Args:
+            logits: The logits to apply min length penalty to
+            requests: The requests to apply min length penalty to
+            num_steps: The number of steps per request
+
+        Returns:
+            The logits with min length penalty applied
+        """
+        if any(r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]
+               for r in requests):
+            current_offset = 0
+            for index, r in enumerate(requests):
+                if r.py_min_length:
+                    for step in range(num_steps[index]):
+                        if r.max_beam_num_tokens + step < r.py_min_length[0]:
+                            logits[current_offset + step,
+                                   r.py_end_id] = float('-inf')
+                        else:
+                            # early exit
+                            break
+                current_offset += num_steps[index]
+        return logits
+
     def _process_requests(self,
                           scheduled_requests: ScheduledRequests,
                           model_outputs: dict[str, torch.Tensor],
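The helper treats `logits` as a flat tensor in which request `i` owns `num_steps[i]` consecutive rows. For every step that would still leave the sequence shorter than `py_min_length[0]`, it overwrites the logit of the request's end token with `-inf`; once the threshold is reached it breaks out early and leaves the remaining rows untouched. A standalone toy sketch (plain PyTorch, invented sizes) of why the `-inf` write is enough to keep generation going:

```python
import torch

vocab_size, end_id = 8, 7
logits = torch.randn(1, vocab_size)
logits[0, end_id] = 5.0                     # EOS would normally dominate

masked = logits.clone()
masked[0, end_id] = float('-inf')           # the same write the helper performs

print(torch.softmax(logits, dim=-1)[0, end_id])  # close to 1.0: EOS very likely
print(torch.softmax(masked, dim=-1)[0, end_id])  # exactly 0.0: EOS impossible
print(int(masked.argmax(dim=-1)))                # greedy pick is never end_id
```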
@@ -686,6 +716,8 @@ class TorchSampler(Sampler):

         requests = scheduled_requests.all_requests()
         num_steps = [1 + get_draft_token_length(req) for req in requests]
+        raw_logits = self._apply_min_length_penalty(raw_logits, requests,
+                                                    num_steps)
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
         fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None
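Because `num_steps[i] = 1 + get_draft_token_length(req)`, a request that carries draft tokens owns several consecutive rows of `raw_logits`, which is why the penalty is applied to the flat tensor before the fast-path split. A tiny sketch (invented values) of the offset bookkeeping that `current_offset` performs:

```python
# Three requests; the second one speculates 2 draft tokens.
num_steps = [1, 1 + 2, 1]
offsets = [sum(num_steps[:i]) for i in range(len(num_steps))]
print(offsets)  # [0, 1, 4] -> request i owns rows offsets[i] .. offsets[i] + num_steps[i] - 1
```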
@@ -6,7 +6,7 @@ import pytest
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import GenerationExecutorWorker
 from tensorrt_llm.llmapi import KvCacheConfig
-from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
+from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
 from tensorrt_llm.metrics import MetricNames
 from tensorrt_llm.sampling_params import SamplingParams
@@ -861,15 +861,30 @@ def test_llm_with_proxy_error():


 @pytest.mark.part0
-@pytest.mark.xfail(reason="https://nvbugs/5513423")
-def test_min_tokens():
+@pytest.mark.parametrize("use_speculative", [True, False])
+def test_min_tokens(use_speculative: bool):
     """Check min_tokens is respected."""
-    llm = LLM(model=llama_model_path,
-              kv_cache_config=global_kvcache_config,
-              enable_mixed_sampler=True,
-              max_seq_len=20000)
+    llm_common_config = dict(
+        model=llama_model_path,
+        max_batch_size=2,
+        kv_cache_config=global_kvcache_config,
+        max_num_tokens=2048,
+        enable_mixed_sampler=True,
+    )

-    output_len = 5000
+    if use_speculative:
+        spec_config = NGramDecodingConfig(
+            max_draft_len=4,
+            max_matching_ngram_size=2,
+            is_keep_all=True,
+            is_use_oldest=True,
+            is_public_pool=True,
+        )
+        llm = LLM(**llm_common_config, speculative_config=spec_config)
+    else:
+        llm = LLM(**llm_common_config)
+
+    output_len = 2000
     sampling_params = SamplingParams(max_tokens=output_len,
                                      min_tokens=output_len,
                                      temperature=1)
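The diff view cuts off at the `SamplingParams` construction. A hypothetical continuation in the same spirit (attribute names assume the usual `RequestOutput`/`CompletionOutput` shape of the LLM API and are not taken from this commit) would generate with these parameters and check the produced length:

```python
# Hypothetical check, not part of the commit: with min_tokens == max_tokens,
# every completion should contain exactly output_len generated tokens.
outputs = llm.generate(["The future of AI is"], sampling_params=sampling_params)
for output in outputs:
    assert len(output.outputs[0].token_ids) == output_len
```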