[Performance] Add is_reasoning_end_streaming() override to GptOssReasoningParser (#35745)

Signed-off-by: Fergus <fergus.barratt00@gmail.com>
Signed-off-by: fergus barratt <fergus.barratt00@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
Fergus
2026-04-21 19:31:27 +01:00
committed by GitHub
parent 9f39b380d0
commit 5544f8c18b
2 changed files with 89 additions and 1 deletions
@@ -280,3 +280,72 @@ class TestGptOssStructuralTags:
assert tag["content"]["type"] == "any_text"
assert tag["end"] == "<|end|>"
assert tag["begin"].startswith("<|channel|>")
@pytest.mark.parametrize(
"output, is_reasoning_end",
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
)
def test_gptoss_is_reasoning_end_streaming(
output,
is_reasoning_end,
gpt_oss_tokenizer,
):
"""Streaming override must agree with is_reasoning_end for all cases."""
tokens = gpt_oss_tokenizer.tokenize(output)
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
delta_ids = output_ids[-1:] if output_ids else []
actual = parser.is_reasoning_end_streaming(output_ids, delta_ids)
assert is_reasoning_end == actual
@pytest.mark.parametrize(
"output, is_reasoning_end",
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
)
def test_gptoss_is_reasoning_end_streaming_long_prefix(
output,
is_reasoning_end,
gpt_oss_tokenizer,
):
"""Windowing must produce correct results even with a long prefix."""
tokens = gpt_oss_tokenizer.tokenize(output)
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
# Prepend 10k dummy reasoning tokens to simulate a long generation
long_prefix = [1] * 10_000
padded_ids = long_prefix + list(output_ids)
delta_ids = output_ids[-1:] if output_ids else []
actual = parser.is_reasoning_end_streaming(padded_ids, delta_ids)
assert is_reasoning_end == actual
@pytest.mark.parametrize(
"output, is_reasoning_end",
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
)
def test_gptoss_is_reasoning_end_streaming_large_delta(
output,
is_reasoning_end,
gpt_oss_tokenizer,
):
"""Simulate speculative decoding where the entire test sequence arrives
as a single large delta appended after a long prefix. The window must
expand to cover delta_ids so the end pattern is never missed."""
tokens = gpt_oss_tokenizer.tokenize(output)
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
long_prefix = [1] * 10_000
padded_ids = long_prefix + list(output_ids)
# delta_ids = the entire test sequence (as if accepted in one spec step)
delta_ids = list(output_ids)
actual = parser.is_reasoning_end_streaming(padded_ids, delta_ids)
assert is_reasoning_end == actual
def test_gptoss_is_reasoning_end_streaming_signature(gpt_oss_tokenizer):
"""Verify the method is callable with the expected signature."""
parser = GptOssReasoningParser(gpt_oss_tokenizer)
result = parser.is_reasoning_end_streaming([], [])
assert result is False
+20 -1
View File
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase
@@ -112,6 +112,25 @@ class GptOssReasoningParser(ReasoningParser):
return True
return False
def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool:
# The pattern window covers the end-of-reasoning marker itself.
# We add len(delta_ids) so that under speculative decoding (where
# a single step can accept many tokens) the entire accepted chunk
# is always inside the scan region.
delta_ids = tuple(delta_ids)
pattern_len = (
len(self.reasoning_end_token_ids_prefix)
+ self.reasoning_max_num_between_tokens
+ len(self.reasoning_end_token_ids_suffix)
)
window = pattern_len + len(delta_ids)
n = len(input_ids)
if n <= window:
return self.is_reasoning_end(input_ids)
return self.is_reasoning_end(input_ids[n - window :])
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
_, content, _ = parse_chat_output(input_ids)
if content is None: