mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-19 17:25:17 +08:00

[NvBug 5370718, 5371538] fix: Fix incremental detokenization (#5825)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

parent dc32f9ae73
commit 055c4a9fe6
@@ -376,17 +376,16 @@ class DetokenizedGenerationResultBase(GenerationResultBase):
         if self.sampling_params.detokenize and self.tokenizer is not None:
             for beam_output in self.outputs:
                 beam_output._last_text_len = len(beam_output.text)
-                if hasattr(self.tokenizer, 'decode_incrementally'):
-                    if self._streaming and not self.sampling_params.use_beam_search:
-                        beam_output.text, beam_output._incremental_states = self.tokenizer.decode_incrementally(
-                            beam_output.token_ids_diff,
-                            prev_text=beam_output.text,
-                            states=beam_output._incremental_states,
-                            flush=self._done,
-                            **kwargs)
-                    else:
-                        beam_output.text, _ = self.tokenizer.decode_incrementally(
-                            beam_output.token_ids, flush=self._done, **kwargs)
+                if hasattr(
+                        self.tokenizer, 'decode_incrementally'
+                ) and self._streaming and not self.sampling_params.use_beam_search:
+                    beam_output.text, beam_output._incremental_states = self.tokenizer.decode_incrementally(
+                        beam_output.token_ids_diff,
+                        prev_text=beam_output.text,
+                        states=beam_output._incremental_states,
+                        flush=self._done,
+                        stream_interval=self.sampling_params._stream_interval,
+                        **kwargs)
                 else:
                     beam_output.text = self.tokenizer.decode(
                         beam_output.token_ids, **kwargs)
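The rewritten branch takes the incremental path only for streaming, non-beam-search requests and now forwards the request's `_stream_interval` to the tokenizer; everything else falls back to a plain `tokenizer.decode`. A hedged usage sketch of that incremental loop driven outside the runtime (the gpt2 checkpoint and the 4-token chunking are arbitrary examples, not what TensorRT-LLM does internally):

# Hedged sketch: drive decode_incrementally chunk by chunk, flushing at the end.
from transformers import AutoTokenizer

from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer

hf_tok = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint only
tok = TransformersTokenizer(hf_tok)

token_ids = hf_tok.encode("Incremental detokenization keeps streamed text stable.")
# Pretend the runtime delivers 4 new tokens per streaming step (the token_ids_diff role).
deltas = [token_ids[i:i + 4] for i in range(0, len(token_ids), 4)]

text, states = "", None
for i, delta in enumerate(deltas):
    text, states = tok.decode_incrementally(
        delta,
        prev_text=text,
        states=states,
        flush=(i == len(deltas) - 1),  # flush pending tokens on the last chunk
        stream_interval=4)

print(text)  # should normally match tok.decode(token_ids), up to tokenization-space cleanup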
@@ -499,8 +499,8 @@ class BaseLLM:
                 raise ValueError(
                     "tokenizer is required to initialize a default sampling_params, or you can explicitly specify a sampling_params"
                 )
-            return SamplingParams(end_id=self.tokenizer.eos_token_id,
-                                  pad_id=self.tokenizer.pad_token_id)
+            sampling_params = SamplingParams(end_id=self.tokenizer.eos_token_id,
+                                             pad_id=self.tokenizer.pad_token_id)
         elif isinstance(sampling_params, SamplingParams):
             if sampling_params.end_id is None:
                 if self.tokenizer is None:
@@ -508,21 +508,26 @@ class BaseLLM:
                         "tokenizer is required to reset end_id if it is None, or you can explicitly specify the end_id for sampling_params"
                     )
                 sampling_params._setup(self.tokenizer)
-            # auto enabled context and/or generation logits flags, as they are required by logprob computation for TRT backend.
-            if self.args.backend not in ["pytorch", "_autodeploy"]:
-                if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
-                    sampling_params.return_context_logits = True
-                    sampling_params._context_logits_auto_enabled = True
-                if sampling_params.logprobs and not sampling_params.return_generation_logits:
-                    sampling_params.return_generation_logits = True
-                    sampling_params._generation_logits_auto_enabled = True
-
-            return sampling_params
         else:
             raise TypeError(
                 f"The sampling_params must be type SamplingParams or None, but got {type(sampling_params)}"
             )
 
+        # auto enabled context and/or generation logits flags, as they are required by logprob computation for TRT backend.
+        if self.args.backend not in ["pytorch", "_autodeploy"]:
+            if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
+                sampling_params.return_context_logits = True
+                sampling_params._context_logits_auto_enabled = True
+            if sampling_params.logprobs and not sampling_params.return_generation_logits:
+                sampling_params.return_generation_logits = True
+                sampling_params._generation_logits_auto_enabled = True
+
+        if sampling_params._stream_interval is None:
+            sampling_params._stream_interval = getattr(self.args,
+                                                       "stream_interval", 1)
+
+        return sampling_params
 
     def _check_arguments(self, prompt_len: int, query_len: int,
                          sampling_params: SamplingParams) -> None:
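The net effect of this hunk is that both branches (default-constructed and user-supplied SamplingParams) now fall through to a shared epilogue, so the logits auto-enable flags and the new `_stream_interval` are populated in either case instead of only on the `elif` path. A self-contained toy sketch of that control-flow change (the stand-in class below is not tensorrt_llm's SamplingParams):

# Hedged, self-contained sketch of the post-fix control flow.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToySamplingParams:
    end_id: Optional[int] = None
    pad_id: Optional[int] = None
    _stream_interval: Optional[int] = field(default=None, init=False, repr=False)


def prepare(sp, eos_id, pad_id, args_stream_interval):
    if sp is None:
        sp = ToySamplingParams(end_id=eos_id, pad_id=pad_id)  # no early return anymore
    elif not isinstance(sp, ToySamplingParams):
        raise TypeError("sampling_params must be SamplingParams or None")
    # shared epilogue now runs for both branches
    if sp._stream_interval is None:
        sp._stream_interval = args_stream_interval
    return sp


assert prepare(None, 2, 0, 4)._stream_interval == 4
assert prepare(ToySamplingParams(end_id=2), 2, 0, 4)._stream_interval == 4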
@@ -2,11 +2,23 @@ import os
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-from tokenizers.decoders import DecodeStream
 from transformers import (AutoTokenizer, PreTrainedTokenizerBase,
                           PreTrainedTokenizerFast)
 
 from .._utils import nvtx_range_debug
 from ..logger import logger
 
+TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = os.environ.get(
+    "TLLM_INCREMENTAL_DETOKENIZATION_BACKEND", "HF")
+TLLM_STREAM_INTERVAL_THRESHOLD = int(
+    os.environ.get("TLLM_STREAM_INTERVAL_THRESHOLD", "24"))
+try:
+    from tokenizers.decoders import DecodeStream  # noqa
+except ImportError:
+    logger.warning(
+        f"HF incremental detokenization is unsupported by tokenizer<0.21.0; fallback to TRTLLM incremental detokenization."
+    )
+    TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = "TRTLLM"
+
+
 class TokenizerBase(PreTrainedTokenizerBase):
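Both knobs introduced here are read once at import time from the environment, so the detokenization backend and the HF/TRTLLM cutover can be pinned without code changes, for example (values are only illustrative):

import os

# Must be set before tensorrt_llm.llmapi.tokenizer is first imported.
os.environ["TLLM_INCREMENTAL_DETOKENIZATION_BACKEND"] = "TRTLLM"  # or "HF" (the default)
os.environ["TLLM_STREAM_INTERVAL_THRESHOLD"] = "16"               # default is 24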
@@ -20,9 +32,6 @@ class TransformersTokenizer(TokenizerBase):
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
         self._all_special_tokens_set = set(self.tokenizer.all_special_tokens)
-        self.hf_decode_stream = None
-        self.stream_interval_threshold = int(
-            os.getenv("TLLM_STREAM_INTERVAL_THRESHOLD", "32"))
 
     def __call__(self, text: str, *args, **kwargs) -> Any:
         return self.tokenizer(text, *args, **kwargs)
@@ -129,8 +138,9 @@ class TransformersTokenizer(TokenizerBase):
                              *,
                              flush: bool = False,
                              skip_special_tokens: bool = False,
-                             clean_up_tokenization_spaces: bool = None,
-                             spaces_between_special_tokens: bool = True) -> Tuple[str, dict]:
+                             clean_up_tokenization_spaces: Optional[bool] = None,
+                             spaces_between_special_tokens: bool = True,
+                             stream_interval: int = 1) -> Tuple[str, dict]:
         """Incremental detokenization, typically used for streaming generation.
 
         Args:
@@ -141,22 +151,47 @@ class TransformersTokenizer(TokenizerBase):
             skip_special_tokens (bool): Whether to remove special tokens in the decoding.
             clean_up_tokenization_spaces (bool): Whether to clean up tokenization spaces.
             spaces_between_special_tokens (bool): Whether to add spaces between special tokens.
+            stream_interval (int): The iteration interval to create responses under the streaming mode.
 
         Returns:
             text, states (Tuple[str, dict]): text is the current decoded text, states is the current incremental detokenization states.
                 They should be passed to next incremental detokenization iteration, if any.
         """
+        # HF incremental detokenization implementation is faster than TRTLLM when stream_interval is smaller.
+        if (TLLM_INCREMENTAL_DETOKENIZATION_BACKEND == "TRTLLM"
+                or stream_interval >= TLLM_STREAM_INTERVAL_THRESHOLD
+                or spaces_between_special_tokens is False):
+            return self.trtllm_decode_incrementally(
+                token_ids,
+                prev_text,
+                states,
+                flush=flush,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                spaces_between_special_tokens=spaces_between_special_tokens)
+        else:
+            return self.hf_decode_incrementally(
+                token_ids,
+                prev_text,
+                states,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces)
+
+    def trtllm_decode_incrementally(
+            self,
+            token_ids: List[int],
+            prev_text: Optional[str] = None,
+            states: Optional[dict] = None,
+            *,
+            flush: bool = False,
+            skip_special_tokens: bool = False,
+            clean_up_tokenization_spaces: Optional[bool] = None,
+            spaces_between_special_tokens: bool = True) -> Tuple[str, dict]:
         # Adapted from
         # https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/transformers_utils/detokenizer.py#L238
         if prev_text is None:
             prev_text = ""
 
-        # HF incremental detokenization implementation is faster than TRTLLM
-        # when stream_interval is smaller.
-        if len(token_ids) < self.stream_interval_threshold:
-            return self.hf_decode_incrementally(token_ids, prev_text,
-                                                skip_special_tokens)
-
         if states is None:
             states = {}
         last_new_tokens = states.pop('last_new_tokens', [])
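To summarize the dispatch added above: the TRTLLM implementation is chosen when the backend is forced to "TRTLLM", when `stream_interval` is at or above `TLLM_STREAM_INTERVAL_THRESHOLD` (default 24), or when `spaces_between_special_tokens` is False; otherwise the HF `DecodeStream` path runs. A condensed restatement of that condition as a standalone helper (a sketch, not part of the library):

def uses_trtllm_path(stream_interval: int,
                     spaces_between_special_tokens: bool = True,
                     backend: str = "HF",
                     threshold: int = 24) -> bool:
    # Mirrors the condition in decode_incrementally above.
    return (backend == "TRTLLM" or stream_interval >= threshold
            or spaces_between_special_tokens is False)


assert uses_trtllm_path(stream_interval=64)        # large interval -> TRTLLM path
assert not uses_trtllm_path(stream_interval=4)     # small interval -> HF DecodeStream path
assert uses_trtllm_path(stream_interval=4, spaces_between_special_tokens=False)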
@@ -193,24 +228,37 @@ class TransformersTokenizer(TokenizerBase):
             curr_new_text = self.clean_up_tokenization(curr_new_text)
         return prev_text + curr_new_text, {'last_new_tokens': pending_tokens}
 
     @nvtx_range_debug("hf_decode_incrementally")
-    def hf_decode_incrementally(self,
-                                token_ids: List[int],
-                                prev_text: Optional[str] = "",
-                                skip_special_tokens: bool = False) -> str:
-        if self.hf_decode_stream is None:
-            # Lazy initialize DecodeStream since it requires skip_special_tokens
-            self.hf_decode_stream = DecodeStream(
-                skip_special_tokens=skip_special_tokens)
+    def hf_decode_incrementally(
+            self,
+            token_ids: List[int],
+            prev_text: Optional[str] = None,
+            states: Optional[dict] = None,
+            *,
+            skip_special_tokens: bool = False,
+            clean_up_tokenization_spaces: Optional[bool] = None
+    ) -> Tuple[str, dict]:
+        if states is None:
+            states = {
+                'decode_stream':
+                    DecodeStream(skip_special_tokens=skip_special_tokens)
+            }
 
-        results = []
-        for token_id in token_ids:
-            result = self.hf_decode_stream.step(self.tokenizer._tokenizer,
-                                                token_id)
-            if result is not None:
-                results.append(result)
+        decode_stream = states.get('decode_stream')
+        results = [
+            result for tid in token_ids
+            if (result := decode_stream.step(self.tokenizer._tokenizer, tid)
+                ) is not None
+        ]
+        curr_new_text = "".join(results)
+        if clean_up_tokenization_spaces is None:
+            clean_up_tokenization_spaces = self.clean_up_tokenization_spaces
+        if clean_up_tokenization_spaces:
+            curr_new_text = self.clean_up_tokenization(curr_new_text)
 
-        return prev_text + "".join(results), None
+        if prev_text is None:
+            return curr_new_text, states
+        else:
+            return prev_text + curr_new_text, states
 
 
 def tokenizer_factory(obj: Optional[Union[str, Path, PreTrainedTokenizerBase,
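The HF path now keeps its `DecodeStream` inside the per-call `states` dict rather than caching it on the tokenizer object, so each request gets its own decoder state. For reference, the underlying `tokenizers` API it wraps can be exercised directly like this (needs tokenizers>=0.21 and network access; the gpt2 checkpoint is only an example):

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tok = Tokenizer.from_pretrained("gpt2")       # example checkpoint only
stream = DecodeStream(skip_special_tokens=False)

pieces = []
for token_id in tok.encode("Hello world, streamed one token at a time.").ids:
    piece = stream.step(tok, token_id)        # returns None until a full piece is decodable
    if piece is not None:
        pieces.append(piece)
print("".join(pieces))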
@@ -269,6 +269,9 @@ class SamplingParams:
     truncate_prompt_tokens: Optional[int] = None
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
+    # Currently, _stream_interval is only used to pass llm.args.stream_interval to tokenizer.
+    # TODO: make this a per-request parameter.
+    _stream_interval: Optional[int] = field(default=None, init=False, repr=False)
 
     def __post_init__(self):
         if self.pad_id is None:
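Since the field is declared with `init=False`, it is not accepted by the SamplingParams constructor; it stays None until `BaseLLM._prepare_sampling_params` (hunk above) copies `llm.args.stream_interval` into it. A tiny illustration, assuming SamplingParams is importable from tensorrt_llm.llmapi:

from tensorrt_llm.llmapi import SamplingParams  # assumed import path

sp = SamplingParams(max_tokens=32)
print(sp._stream_interval)  # None here; filled in later by the LLM from llm.args.stream_interval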
@@ -60,14 +60,14 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
     @pytest.mark.parametrize("stream_interval", [4, 64],
                              ids=["stream_interval_4", "stream_interval_64"])
     def test_nvfp4_streaming(self, stream_interval):
-        model_path = f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B"
 
-        # When stream_interval < 32, hf incremental detokenization is used.
-        # When stream_interval >= 32, trtllm implemented incremental detokenization is used.
+        # When stream_interval < TLLM_STREAM_INTERVAL_THRESHOLD, hf incremental detokenization is used.
+        # When stream_interval >= TLLM_STREAM_INTERVAL_THRESHOLD, trtllm implemented incremental detokenization is used.
         # The behavior is due to perf considerations, while both paths need to be tested.
-        with LLM(model_path, stream_interval=stream_interval) as llm:
+        with LLM(f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B",
+                 stream_interval=stream_interval) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             assert llm.args.stream_interval == stream_interval
             task = CnnDailymail(self.MODEL_NAME)
             task.evaluate(llm, streaming=True)
@@ -421,7 +421,6 @@ examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (http
 examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
 examples/test_multimodal.py::test_llm_multimodal_general[fuyu-8b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
 examples/test_multimodal.py::test_llm_multimodal_general[llava-1.5-7b-hf-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
-examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5371538)
 test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] SKIP (https://nvbugs/5320234)
 examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5374145)
 examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5373451)
@@ -49,8 +49,8 @@ def test_trtllm_sampler(model_path, test_case):
         "The capital of Bolivia is",
     ]
 
-    expected_outputs = [["circumnavigation of the world."], [" Paris."],
-                        [" La Paz."]]
+    expected_outputs = [["circumnavigation of the world."], ["Paris."],
+                        ["La Paz."]]
 
     # Test configuration
     max_new_tokens = test_case["max_new_tokens"]
@@ -363,22 +363,31 @@ def test_llm_with_kv_cache_retention_config():
         print(output)
 
 
-@pytest.mark.skip(reason="https://nvbugs/5370718")
+@pytest.mark.parametrize('backend', ["HF", "TRTLLM"])
 @pytest.mark.parametrize(
-    'tokenizer_dir, threshold',
+    'tokenizer_dir, clean_up_tokenization_spaces, threshold',
     [
-        (get_model_path('gpt2'), 0.95),  # BPE
-        (get_model_path('bert/bert-base-uncased'), 0.95),  # WordPiece
-        (get_model_path('t5-small'), 0.95),  # SentencePiece
-        (get_model_path('starcoder2-3b'), 0.95),
-        (get_model_path('falcon-7b-instruct'), 0.95),
-        (get_model_path('llama-models-v2/llama-v2-7b-hf'), 0.95),
-        (get_model_path('codellama/CodeLlama-7b-Instruct-hf'), 0.95),
-        (llama_model_path, 0.95),
-        (get_model_path(mixtral_model_name), 0.95)
+        (get_model_path('gpt2'), False, 0.95),  # BPE
+        (get_model_path('bert/bert-base-uncased'), True, 0.95),  # WordPiece
+        (get_model_path('t5-small'), True, 0.95),  # SentencePiece
+        (get_model_path('starcoder2-3b'), False, 0.95),
+        (get_model_path('falcon-7b-instruct'), False, 0.95),
+        (get_model_path('llama-models-v2/llama-v2-7b-hf'), False, 0.95),
+        (get_model_path('codellama/CodeLlama-7b-Instruct-hf'), False, 0.95),
+        (llama_model_path, False, 0.95),
+        (get_model_path(mixtral_model_name), False, 0.95),
+        (get_model_path('llama-3.1-model/Meta-Llama-3.1-8B'), False, 0.95),
+        (get_model_path('DeepSeek-R1/DeepSeek-R1'), False, 0.95)
     ])
 @pytest.mark.part0
-def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
+def test_tokenizer_decode_incrementally(tokenizer_dir: str,
+                                        clean_up_tokenization_spaces: bool,
+                                        threshold: float, backend: str, mocker):
+    import tensorrt_llm.llmapi.tokenizer
+    mocker.patch.object(tensorrt_llm.llmapi.tokenizer,
+                        "TLLM_INCREMENTAL_DETOKENIZATION_BACKEND", backend)
+    assert tensorrt_llm.llmapi.tokenizer.TLLM_INCREMENTAL_DETOKENIZATION_BACKEND == backend
+
     random.seed(42)
 
     num_samples = 100
@@ -410,8 +419,7 @@ def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
            decoded_text, states = tokenizer.decode_incrementally(
                [token_ids[i]], decoded_text, states)
 
-    if tokenizer_dir.endswith(
-            'bert-base-uncased') and tokenizer.clean_up_tokenization_spaces:
+    if clean_up_tokenization_spaces and tokenizer.clean_up_tokenization_spaces:
        decoded_text = tokenizer.clean_up_tokenization(decoded_text)
    reference = tokenizer.decode(token_ids)
    if decoded_text == reference: