diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index abd1a8649d..9cd539f33b 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -376,17 +376,16 @@ class DetokenizedGenerationResultBase(GenerationResultBase):
         if self.sampling_params.detokenize and self.tokenizer is not None:
             for beam_output in self.outputs:
                 beam_output._last_text_len = len(beam_output.text)
-                if hasattr(self.tokenizer, 'decode_incrementally'):
-                    if self._streaming and not self.sampling_params.use_beam_search:
-                        beam_output.text, beam_output._incremental_states = self.tokenizer.decode_incrementally(
-                            beam_output.token_ids_diff,
-                            prev_text=beam_output.text,
-                            states=beam_output._incremental_states,
-                            flush=self._done,
-                            **kwargs)
-                    else:
-                        beam_output.text, _ = self.tokenizer.decode_incrementally(
-                            beam_output.token_ids, flush=self._done, **kwargs)
+                if hasattr(
+                        self.tokenizer, 'decode_incrementally'
+                ) and self._streaming and not self.sampling_params.use_beam_search:
+                    beam_output.text, beam_output._incremental_states = self.tokenizer.decode_incrementally(
+                        beam_output.token_ids_diff,
+                        prev_text=beam_output.text,
+                        states=beam_output._incremental_states,
+                        flush=self._done,
+                        stream_interval=self.sampling_params._stream_interval,
+                        **kwargs)
                 else:
                     beam_output.text = self.tokenizer.decode(
                         beam_output.token_ids, **kwargs)
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 6f3adcbda2..9cd606e322 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -499,8 +499,8 @@ class BaseLLM:
                 raise ValueError(
                     "tokenizer is required to initialize a default sampling_params, or you can explicitly specify a sampling_params"
                 )
-            return SamplingParams(end_id=self.tokenizer.eos_token_id,
-                                  pad_id=self.tokenizer.pad_token_id)
+            sampling_params = SamplingParams(end_id=self.tokenizer.eos_token_id,
+                                             pad_id=self.tokenizer.pad_token_id)
         elif isinstance(sampling_params, SamplingParams):
             if sampling_params.end_id is None:
                 if self.tokenizer is None:
@@ -508,21 +508,26 @@
                         "tokenizer is required to reset end_id if it is None, or you can explicitly specify the end_id for sampling_params"
                     )
                 sampling_params._setup(self.tokenizer)
-            # auto enabled context and/or generation logits flags, as they are required by logprob computation for TRT backend.
-            if self.args.backend not in ["pytorch", "_autodeploy"]:
-                if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
-                    sampling_params.return_context_logits = True
-                    sampling_params._context_logits_auto_enabled = True
-                if sampling_params.logprobs and not sampling_params.return_generation_logits:
-                    sampling_params.return_generation_logits = True
-                    sampling_params._generation_logits_auto_enabled = True
-
-            return sampling_params
         else:
             raise TypeError(
                 f"The sampling_params must be type SamplingParams or None, but got {type(sampling_params)}"
             )
 
+        # auto enabled context and/or generation logits flags, as they are required by logprob computation for TRT backend.
+        if self.args.backend not in ["pytorch", "_autodeploy"]:
+            if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
+                sampling_params.return_context_logits = True
+                sampling_params._context_logits_auto_enabled = True
+            if sampling_params.logprobs and not sampling_params.return_generation_logits:
+                sampling_params.return_generation_logits = True
+                sampling_params._generation_logits_auto_enabled = True
+
+        if sampling_params._stream_interval is None:
+            sampling_params._stream_interval = getattr(self.args,
+                                                       "stream_interval", 1)
+
+        return sampling_params
+
     def _check_arguments(self, prompt_len: int, query_len: int,
                          sampling_params: SamplingParams) -> None:
diff --git a/tensorrt_llm/llmapi/tokenizer.py b/tensorrt_llm/llmapi/tokenizer.py
index 9943338528..858f98289c 100644
--- a/tensorrt_llm/llmapi/tokenizer.py
+++ b/tensorrt_llm/llmapi/tokenizer.py
@@ -2,11 +2,23 @@ import os
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-from tokenizers.decoders import DecodeStream
 from transformers import (AutoTokenizer, PreTrainedTokenizerBase,
                           PreTrainedTokenizerFast)
 
 from .._utils import nvtx_range_debug
+from ..logger import logger
+
+TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = os.environ.get(
+    "TLLM_INCREMENTAL_DETOKENIZATION_BACKEND", "HF")
+TLLM_STREAM_INTERVAL_THRESHOLD = int(
+    os.environ.get("TLLM_STREAM_INTERVAL_THRESHOLD", "24"))
+try:
+    from tokenizers.decoders import DecodeStream  # noqa
+except ImportError:
+    logger.warning(
+        f"HF incremental detokenization is unsupported by tokenizer<0.21.0; fallback to TRTLLM incremental detokenization."
+    )
+    TLLM_INCREMENTAL_DETOKENIZATION_BACKEND = "TRTLLM"
 
 
 class TokenizerBase(PreTrainedTokenizerBase):
@@ -20,9 +32,6 @@ class TransformersTokenizer(TokenizerBase):
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
         self._all_special_tokens_set = set(self.tokenizer.all_special_tokens)
-        self.hf_decode_stream = None
-        self.stream_interval_threshold = int(
-            os.getenv("TLLM_STREAM_INTERVAL_THRESHOLD", "32"))
 
     def __call__(self, text: str, *args, **kwargs) -> Any:
         return self.tokenizer(text, *args, **kwargs)
@@ -129,8 +138,9 @@ class TransformersTokenizer(TokenizerBase):
             *,
             flush: bool = False,
             skip_special_tokens: bool = False,
-            clean_up_tokenization_spaces: bool = None,
-            spaces_between_special_tokens: bool = True) -> Tuple[str, dict]:
+            clean_up_tokenization_spaces: Optional[bool] = None,
+            spaces_between_special_tokens: bool = True,
+            stream_interval: int = 1) -> Tuple[str, dict]:
         """Incremental detokenization, typically used for streaming generation.
 
         Args:
@@ -141,22 +151,47 @@ class TransformersTokenizer(TokenizerBase):
             skip_special_tokens (bool): Whether to remove special tokens in the decoding.
             clean_up_tokenization_spaces (bool): Whether to clean up tokenization spaces.
             spaces_between_special_tokens (bool): Whether to add spaces between special tokens.
+            stream_interval (int): The iteration interval to create responses under the streaming mode.
 
         Returns:
             text, states (Tuple[str, dict]): text is the current decoded text, states is the current incremental detokenization states.
                 They should be passed to next incremental detokenization iteration, if any.
         """
+        # HF incremental detokenization implementation is faster than TRTLLM when stream_interval is smaller.
+        if (TLLM_INCREMENTAL_DETOKENIZATION_BACKEND == "TRTLLM"
+                or stream_interval >= TLLM_STREAM_INTERVAL_THRESHOLD
+                or spaces_between_special_tokens is False):
+            return self.trtllm_decode_incrementally(
+                token_ids,
+                prev_text,
+                states,
+                flush=flush,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                spaces_between_special_tokens=spaces_between_special_tokens)
+        else:
+            return self.hf_decode_incrementally(
+                token_ids,
+                prev_text,
+                states,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces)
+
+    def trtllm_decode_incrementally(
+            self,
+            token_ids: List[int],
+            prev_text: Optional[str] = None,
+            states: Optional[dict] = None,
+            *,
+            flush: bool = False,
+            skip_special_tokens: bool = False,
+            clean_up_tokenization_spaces: Optional[bool] = None,
+            spaces_between_special_tokens: bool = True) -> Tuple[str, dict]:
         # Adapted from
         # https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/transformers_utils/detokenizer.py#L238
         if prev_text is None:
             prev_text = ""
-        # HF incremental detokenization implementation is faster than TRTLLM
-        # when stream_interval is smaller.
-        if len(token_ids) < self.stream_interval_threshold:
-            return self.hf_decode_incrementally(token_ids, prev_text,
-                                                skip_special_tokens)
-
         if states is None:
             states = {}
         last_new_tokens = states.pop('last_new_tokens', [])
@@ -193,24 +228,37 @@ class TransformersTokenizer(TokenizerBase):
             curr_new_text = self.clean_up_tokenization(curr_new_text)
         return prev_text + curr_new_text, {'last_new_tokens': pending_tokens}
 
-    @nvtx_range_debug("hf_decode_incrementally")
-    def hf_decode_incrementally(self,
-                                token_ids: List[int],
-                                prev_text: Optional[str] = "",
-                                skip_special_tokens: bool = False) -> str:
-        if self.hf_decode_stream is None:
-            # Lazy initialize DecodeStream since it requires skip_special_tokens
-            self.hf_decode_stream = DecodeStream(
-                skip_special_tokens=skip_special_tokens)
+    def hf_decode_incrementally(
+            self,
+            token_ids: List[int],
+            prev_text: Optional[str] = None,
+            states: Optional[dict] = None,
+            *,
+            skip_special_tokens: bool = False,
+            clean_up_tokenization_spaces: Optional[bool] = None
+    ) -> Tuple[str, dict]:
+        if states is None:
+            states = {
+                'decode_stream':
+                DecodeStream(skip_special_tokens=skip_special_tokens)
+            }
 
-        results = []
-        for token_id in token_ids:
-            result = self.hf_decode_stream.step(self.tokenizer._tokenizer,
-                                                token_id)
-            if result is not None:
-                results.append(result)
+        decode_stream = states.get('decode_stream')
+        results = [
+            result for tid in token_ids
+            if (result := decode_stream.step(self.tokenizer._tokenizer, tid)
+                ) is not None
+        ]
+        curr_new_text = "".join(results)
+        if clean_up_tokenization_spaces is None:
+            clean_up_tokenization_spaces = self.clean_up_tokenization_spaces
+        if clean_up_tokenization_spaces:
+            curr_new_text = self.clean_up_tokenization(curr_new_text)
 
-        return prev_text + "".join(results), None
+        if prev_text is None:
+            return curr_new_text, states
+        else:
+            return prev_text + curr_new_text, states
 
 
 def tokenizer_factory(obj: Optional[Union[str, Path, PreTrainedTokenizerBase,
diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py
index 42ccf02f13..c2ac3b881d 100644
--- a/tensorrt_llm/sampling_params.py
+++ b/tensorrt_llm/sampling_params.py
@@ -269,6 +269,9 @@ class SamplingParams:
     truncate_prompt_tokens: Optional[int] = None
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
+    # Currently, _stream_interval is only used to pass llm.args.stream_interval to tokenizer.
+    # TODO: make this a per-request parameter.
+    _stream_interval: Optional[int] = field(default=None, init=False, repr=False)
 
     def __post_init__(self):
         if self.pad_id is None:
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 6ffe10e698..27d9c539c1 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -60,14 +60,14 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
     @pytest.mark.parametrize("stream_interval", [4, 64],
                              ids=["stream_interval_4", "stream_interval_64"])
     def test_nvfp4_streaming(self, stream_interval):
-        model_path = f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B"
-
-        # When stream_interval < 32, hf incremental detokenization is used.
-        # When stream_interval >= 32, trtllm implemented incremental detokenization is used.
+        # When stream_interval < TLLM_STREAM_INTERVAL_THRESHOLD, hf incremental detokenization is used.
+        # When stream_interval >= TLLM_STREAM_INTERVAL_THRESHOLD, trtllm implemented incremental detokenization is used.
         # The behavior is due to perf considerations, while both paths need to be tested.
-        with LLM(model_path, stream_interval=stream_interval) as llm:
+        with LLM(f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B",
+                 stream_interval=stream_interval) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
+            assert llm.args.stream_interval == stream_interval
 
             task = CnnDailymail(self.MODEL_NAME)
             task.evaluate(llm, streaming=True)
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 2395efe8be..b087953450 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -421,7 +421,6 @@ examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (http
 examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
 examples/test_multimodal.py::test_llm_multimodal_general[fuyu-8b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
 examples/test_multimodal.py::test_llm_multimodal_general[llava-1.5-7b-hf-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
-examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5371538)
 test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] SKIP (https://nvbugs/5320234)
 examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5374145)
 examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5373451)
diff --git a/tests/unittest/_torch/test_trtllm_sampler.py b/tests/unittest/_torch/test_trtllm_sampler.py
index 914d598f0e..e8d6b2f9d8 100644
--- a/tests/unittest/_torch/test_trtllm_sampler.py
+++ b/tests/unittest/_torch/test_trtllm_sampler.py
@@ -49,8 +49,8 @@ def test_trtllm_sampler(model_path, test_case):
         "The capital of Bolivia is",
     ]
 
-    expected_outputs = [["circumnavigation of the world."], [" Paris."],
-                        [" La Paz."]]
+    expected_outputs = [["circumnavigation of the world."], ["Paris."],
+                        ["La Paz."]]
 
     # Test configuration
     max_new_tokens = test_case["max_new_tokens"]
diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index ad87291aad..206f6e1c23 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -363,22 +363,31 @@ def test_llm_with_kv_cache_retention_config():
         print(output)
 
 
-@pytest.mark.skip(reason="https://nvbugs/5370718")
+@pytest.mark.parametrize('backend', ["HF", "TRTLLM"])
 @pytest.mark.parametrize(
-    'tokenizer_dir, threshold',
+    'tokenizer_dir, clean_up_tokenization_spaces, threshold',
     [
-        (get_model_path('gpt2'), 0.95),  # BPE
-        (get_model_path('bert/bert-base-uncased'), 0.95),  # WordPiece
-        (get_model_path('t5-small'), 0.95),  # SentencePiece
-        (get_model_path('starcoder2-3b'), 0.95),
-        (get_model_path('falcon-7b-instruct'), 0.95),
-        (get_model_path('llama-models-v2/llama-v2-7b-hf'), 0.95),
-        (get_model_path('codellama/CodeLlama-7b-Instruct-hf'), 0.95),
-        (llama_model_path, 0.95),
-        (get_model_path(mixtral_model_name), 0.95)
+        (get_model_path('gpt2'), False, 0.95),  # BPE
+        (get_model_path('bert/bert-base-uncased'), True, 0.95),  # WordPiece
+        (get_model_path('t5-small'), True, 0.95),  # SentencePiece
+        (get_model_path('starcoder2-3b'), False, 0.95),
+        (get_model_path('falcon-7b-instruct'), False, 0.95),
+        (get_model_path('llama-models-v2/llama-v2-7b-hf'), False, 0.95),
+        (get_model_path('codellama/CodeLlama-7b-Instruct-hf'), False, 0.95),
+        (llama_model_path, False, 0.95),
+        (get_model_path(mixtral_model_name), False, 0.95),
+        (get_model_path('llama-3.1-model/Meta-Llama-3.1-8B'), False, 0.95),
+        (get_model_path('DeepSeek-R1/DeepSeek-R1'), False, 0.95)
     ])
 @pytest.mark.part0
-def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
+def test_tokenizer_decode_incrementally(tokenizer_dir: str,
+                                        clean_up_tokenization_spaces: bool,
+                                        threshold: float, backend: str, mocker):
+    import tensorrt_llm.llmapi.tokenizer
+    mocker.patch.object(tensorrt_llm.llmapi.tokenizer,
+                        "TLLM_INCREMENTAL_DETOKENIZATION_BACKEND", backend)
+    assert tensorrt_llm.llmapi.tokenizer.TLLM_INCREMENTAL_DETOKENIZATION_BACKEND == backend
+
     random.seed(42)
 
     num_samples = 100
@@ -410,8 +419,7 @@ def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
         decoded_text, states = tokenizer.decode_incrementally(
             [token_ids[i]], decoded_text, states)
 
-    if tokenizer_dir.endswith(
-            'bert-base-uncased') and tokenizer.clean_up_tokenization_spaces:
+    if clean_up_tokenization_spaces and tokenizer.clean_up_tokenization_spaces:
         decoded_text = tokenizer.clean_up_tokenization(decoded_text)
     reference = tokenizer.decode(token_ids)
     if decoded_text == reference:
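
Reviewer note (not part of the patch): a minimal sketch of how the new dispatch in TransformersTokenizer.decode_incrementally could be exercised directly, assuming an installed tensorrt_llm build that contains this change and a locally cached gpt2 tokenizer. With the default backend "HF" and stream_interval=4, which stays below the default TLLM_STREAM_INTERVAL_THRESHOLD of 24, the HF DecodeStream path is taken; setting TLLM_INCREMENTAL_DETOKENIZATION_BACKEND=TRTLLM or passing stream_interval >= 24 would route to trtllm_decode_incrementally instead.

import os

# The module-level knobs are read at import time, so set them before importing
# tensorrt_llm.llmapi.tokenizer. The values shown are the defaults.
os.environ.setdefault("TLLM_INCREMENTAL_DETOKENIZATION_BACKEND", "HF")
os.environ.setdefault("TLLM_STREAM_INTERVAL_THRESHOLD", "24")

from transformers import AutoTokenizer

from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed locally available
tokenizer = TransformersTokenizer(hf_tokenizer)
token_ids = hf_tokenizer.encode("The capital of Bolivia is La Paz.")

decoded_text, states = None, None
for tid in token_ids:
    # Feed one token per step, carrying prev_text and states forward, the same
    # way test_tokenizer_decode_incrementally does. stream_interval=4 keeps the
    # HF DecodeStream path selected.
    decoded_text, states = tokenizer.decode_incrementally([tid],
                                                          prev_text=decoded_text,
                                                          states=states,
                                                          stream_interval=4)
print(decoded_text)  # compare against hf_tokenizer.decode(token_ids)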
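
Also not part of the patch: a sketch of the end-to-end path the accuracy test covers, under the assumption of a placeholder model path, a supported GPU, and the usual LLM API streaming pattern. LLM(stream_interval=...) is copied into SamplingParams._stream_interval by _prepare_sampling_params and then forwarded to decode_incrementally while streaming.

import asyncio

from tensorrt_llm import LLM, SamplingParams


async def stream_once() -> None:
    # stream_interval=4 keeps the HF DecodeStream detokenizer in play; 64 would
    # exceed the default TLLM_STREAM_INTERVAL_THRESHOLD and use the TRTLLM path.
    with LLM("/path/to/Meta-Llama-3.1-8B", stream_interval=4) as llm:  # placeholder path
        sampling_params = SamplingParams(max_tokens=32)
        async for output in llm.generate_async("The capital of Bolivia is",
                                               sampling_params=sampling_params,
                                               streaming=True):
            # Each streamed response carries the text detokenized so far.
            print(output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(stream_once())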