[https://nvbugs/5769815][fix] Fix offset calculation in _are_stop_words when using speculative decoding (#10854)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Authored by Stefan Niebler on 2026-01-26 16:21:59 +01:00; committed by Yanchao Lu
parent 6c4e0c3dbe
commit d50010cd1f
3 changed files with 15 additions and 3 deletions

@@ -3040,11 +3040,12 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
         words_device = words.to("cuda", non_blocking=True)
         draft_token_length = get_draft_token_length(request)
+        max_draft_token_length = self.max_tokens - 1
         for beam_idx in range(self.max_beam_width):
             new_tokens = padded_tokens[request_idx, beam_idx]
             for step_idx in range(draft_token_length + 1):
-                size_per_step = new_tokens.size(0) - draft_token_length + step_idx
+                size_per_step = new_tokens.size(0) - max_draft_token_length + step_idx
                 matches = []
                 for word, L in zip(words_device, lens):
                     truncated_seq = new_tokens[size_per_step - L : size_per_step]
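
For context, a rough sketch of the indexing this hunk changes. It assumes the per-request token buffer is padded so that the last max_tokens - 1 positions are reserved for draft steps even when the request accepted fewer draft tokens; the concrete layout, token values, and the stop_word variable below are illustrative only, not taken from the sampler.

# Hypothetical, standalone illustration of the offset fix (not TorchSampler code).
import torch

max_tokens = 4                          # 1 target token + up to 3 draft tokens
max_draft_token_length = max_tokens - 1
draft_token_length = 1                  # this request kept only 1 draft token

# prompt [7, 8, 6], then new tokens [60, 12] for steps 0..1, then padding
new_tokens = torch.tensor([7, 8, 6, 60, 12, 0, 0])
stop_word = torch.tensor([60, 12])
L = stop_word.numel()

for step_idx in range(draft_token_length + 1):
    old_end = new_tokens.size(0) - draft_token_length + step_idx      # buggy offset
    new_end = new_tokens.size(0) - max_draft_token_length + step_idx  # fixed offset
    print(step_idx,
          new_tokens[old_end - L:old_end].tolist(),   # reads into the padding
          new_tokens[new_end - L:new_end].tolist())   # ends at the real step token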

@@ -253,7 +253,6 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_vswa_reuse_4gpus[one_m
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_vswa_reuse_4gpus[two_model] SKIP (https://nvbugs/5756028)
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized SKIP (https://nvbugs/5785465)
 accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8 SKIP (https://nvbugs/5785485)
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5769815)
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_prefill[tp8ep8-cuda_graph=False] SKIP (https://nvbugs/5795918)

@@ -655,10 +655,13 @@ class RequestCase:
         finish_reasons: list[FinishReason],
         max_new_tokens: int = MAX_NEW_TOKENS,
         end_id: Optional[int] = None,
+        num_draft_tokens: int | None = None,
         stop_words_list: Optional[list[list[int]]] = None,
     ):
         seq_slot = self.seq_slots.pop()  # random seq slot in MAX_NUM_SEQUENCES
         self.prompt = prompt
+        if num_draft_tokens is None:
+            num_draft_tokens = len(new_tokens) - 1
         self.request = LlmRequest(
             request_id=seq_slot,
             seq_slot=seq_slot,
@@ -670,7 +673,7 @@
             end_id=end_id,
             sampling_config=SamplingConfig(),
             is_streaming=False,
-            draft_tokens=new_tokens[:-1],
+            draft_tokens=new_tokens[:num_draft_tokens],
         )
         assert len(new_tokens) == len(finish_reasons)
         self.new_tokens = new_tokens
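
In the test harness hunks above and below, the new num_draft_tokens argument lets a case build a request with fewer draft tokens than len(new_tokens) - 1, including none at all. A quick standalone sketch of the defaulting rule; resolve_draft_tokens is a hypothetical helper, not part of the test file.

# Hypothetical restatement of the default added to RequestCase.__init__.
def resolve_draft_tokens(new_tokens, num_draft_tokens=None):
    if num_draft_tokens is None:
        num_draft_tokens = len(new_tokens) - 1  # old behaviour: all but the last token are drafts
    return new_tokens[:num_draft_tokens]

assert resolve_draft_tokens([60, 12, 13]) == [60, 12]  # default, as before
assert resolve_draft_tokens([60, 12, 13], 0) == []     # new: no draft tokens at all
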
@@ -767,6 +770,15 @@ def test_write_finish_reasons():
             new_tokens=[12, 13, 60],
             finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
         ),
+        RequestCase(
+            prompt=[7, 8, 6],
+            stop_words_list=[[12, 13]],
+            new_tokens=[60, 12, 13],
+            # The request has stop words, but no draft is created,
+            # so tokens at indices greater than 0 should be ignored.
+            num_draft_tokens=0,
+            finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED],
+        ),
         RequestCase(
             prompt=[1, 2, 3, 4],
             end_id=99,