From d50010cd1f3fb2f67ee388dd2aea5bfa15a4ac0e Mon Sep 17 00:00:00 2001
From: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Date: Mon, 26 Jan 2026 16:21:59 +0100
Subject: [PATCH] [https://nvbugs/5769815][fix] Fix offset calculation in _are_stop_words when using speculative decoding (#10854)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/sampler.py          |  3 ++-
 tests/integration/test_lists/waives.txt            |  1 -
 .../unittest/_torch/sampler/test_torch_sampler.py  | 14 +++++++++++++-
 3 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 9fdaf22fb1..da33c6cb30 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -3040,11 +3040,12 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         words_device = words.to("cuda", non_blocking=True)
 
         draft_token_length = get_draft_token_length(request)
+        max_draft_token_length = self.max_tokens - 1
 
         for beam_idx in range(self.max_beam_width):
             new_tokens = padded_tokens[request_idx, beam_idx]
             for step_idx in range(draft_token_length + 1):
-                size_per_step = new_tokens.size(0) - draft_token_length + step_idx
+                size_per_step = new_tokens.size(0) - max_draft_token_length + step_idx
                 matches = []
                 for word, L in zip(words_device, lens):
                     truncated_seq = new_tokens[size_per_step - L : size_per_step]
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 3d36e4a2de..00d88b7686 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -253,7 +253,6 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_vswa_reuse_4gpus[one_m
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_vswa_reuse_4gpus[two_model] SKIP (https://nvbugs/5756028)
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized SKIP (https://nvbugs/5785465)
 accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8 SKIP (https://nvbugs/5785485)
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5769815)
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_prefill[tp8ep8-cuda_graph=False] SKIP (https://nvbugs/5795918)
diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py
index 0dc5dc0ad5..9f64b58771 100644
--- a/tests/unittest/_torch/sampler/test_torch_sampler.py
+++ b/tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -655,10 +655,13 @@ class RequestCase:
         finish_reasons: list[FinishReason],
         max_new_tokens: int = MAX_NEW_TOKENS,
         end_id: Optional[int] = None,
+        num_draft_tokens: int | None = None,
         stop_words_list: Optional[list[list[int]]] = None,
     ):
         seq_slot = self.seq_slots.pop()  # random seq slot in MAX_NUM_SEQUENCES
         self.prompt = prompt
+        if num_draft_tokens is None:
+            num_draft_tokens = len(new_tokens) - 1
         self.request = LlmRequest(
             request_id=seq_slot,
             seq_slot=seq_slot,
@@ -670,7 +673,7 @@
             end_id=end_id,
             sampling_config=SamplingConfig(),
             is_streaming=False,
-            draft_tokens=new_tokens[:-1],
+            draft_tokens=new_tokens[:num_draft_tokens],
         )
         assert len(new_tokens) == len(finish_reasons)
         self.new_tokens = new_tokens
@@ -767,6 +770,15 @@ def test_write_finish_reasons():
             new_tokens=[12, 13, 60],
             finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
         ),
+        RequestCase(
+            prompt=[7, 8, 6],
+            stop_words_list=[[12, 13]],
+            new_tokens=[60, 12, 13],
+            # The request has stop words, but no draft is created
+            # Tokens at indices greater than 0 should be ignored
+            num_draft_tokens=0,
+            finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED],
+        ),
         RequestCase(
             prompt=[1, 2, 3, 4],
             end_id=99,
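
For reference, a minimal standalone sketch of the corrected window arithmetic, using plain torch tensors instead of the sampler's internal buffers. The helper name stop_word_hits and its arguments are illustrative, not part of the repository's API. It assumes each row of the padded token buffer ends with max_draft_len + 1 slots for the current iteration, of which only num_draft_tokens + 1 hold real tokens, so the window end must be anchored at max_draft_len (as in max_draft_token_length = self.max_tokens - 1 above) rather than at the per-request draft length.

    import torch

    def stop_word_hits(tokens: torch.Tensor, stop_words: list[list[int]],
                       max_draft_len: int, num_draft_tokens: int) -> list[bool]:
        # For each accepted step, report whether any stop word ends the sequence.
        hits = []
        for step_idx in range(num_draft_tokens + 1):
            # Exclusive end of the tokens accepted up to this step. Anchoring at
            # max_draft_len (the padded width minus one) keeps the window inside
            # the valid tokens even when this request produced fewer draft
            # tokens than the padding allows.
            end = tokens.size(0) - max_draft_len + step_idx
            step_hit = False
            for word in stop_words:
                L = len(word)
                window = tokens[max(end - L, 0):end]
                if L <= end and torch.equal(window, torch.tensor(word)):
                    step_hit = True
            hits.append(step_hit)
        return hits

    # Loosely mirrors the new test case: the stop word [12, 13] only appears in
    # padding because num_draft_tokens is 0, so no step may report a hit.
    tokens = torch.tensor([7, 8, 60, 12, 13])  # history [7, 8] + 3 padded slots
    print(stop_word_hits(tokens, [[12, 13]], max_draft_len=2, num_draft_tokens=0))  # [False]

With the pre-fix anchoring (subtracting the per-request draft length of 0), the step-0 window would end at the last padded slot and the sketch would wrongly report a hit on [12, 13], which is exactly the behavior the added unit-test case guards against.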