mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Performance] Add is_reasoning_end_streaming() override to GptOssReasoningParser (#35745)
Signed-off-by: Fergus <fergus.barratt00@gmail.com> Signed-off-by: fergus barratt <fergus.barratt00@gmail.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -280,3 +280,72 @@ class TestGptOssStructuralTags:
|
||||
assert tag["content"]["type"] == "any_text"
|
||||
assert tag["end"] == "<|end|>"
|
||||
assert tag["begin"].startswith("<|channel|>")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output, is_reasoning_end",
|
||||
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
|
||||
)
|
||||
def test_gptoss_is_reasoning_end_streaming(
|
||||
output,
|
||||
is_reasoning_end,
|
||||
gpt_oss_tokenizer,
|
||||
):
|
||||
"""Streaming override must agree with is_reasoning_end for all cases."""
|
||||
tokens = gpt_oss_tokenizer.tokenize(output)
|
||||
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
|
||||
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
|
||||
delta_ids = output_ids[-1:] if output_ids else []
|
||||
actual = parser.is_reasoning_end_streaming(output_ids, delta_ids)
|
||||
assert is_reasoning_end == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output, is_reasoning_end",
|
||||
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
|
||||
)
|
||||
def test_gptoss_is_reasoning_end_streaming_long_prefix(
|
||||
output,
|
||||
is_reasoning_end,
|
||||
gpt_oss_tokenizer,
|
||||
):
|
||||
"""Windowing must produce correct results even with a long prefix."""
|
||||
tokens = gpt_oss_tokenizer.tokenize(output)
|
||||
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
|
||||
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
|
||||
# Prepend 10k dummy reasoning tokens to simulate a long generation
|
||||
long_prefix = [1] * 10_000
|
||||
padded_ids = long_prefix + list(output_ids)
|
||||
delta_ids = output_ids[-1:] if output_ids else []
|
||||
actual = parser.is_reasoning_end_streaming(padded_ids, delta_ids)
|
||||
assert is_reasoning_end == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output, is_reasoning_end",
|
||||
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
|
||||
)
|
||||
def test_gptoss_is_reasoning_end_streaming_large_delta(
|
||||
output,
|
||||
is_reasoning_end,
|
||||
gpt_oss_tokenizer,
|
||||
):
|
||||
"""Simulate speculative decoding where the entire test sequence arrives
|
||||
as a single large delta appended after a long prefix. The window must
|
||||
expand to cover delta_ids so the end pattern is never missed."""
|
||||
tokens = gpt_oss_tokenizer.tokenize(output)
|
||||
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
|
||||
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
|
||||
long_prefix = [1] * 10_000
|
||||
padded_ids = long_prefix + list(output_ids)
|
||||
# delta_ids = the entire test sequence (as if accepted in one spec step)
|
||||
delta_ids = list(output_ids)
|
||||
actual = parser.is_reasoning_end_streaming(padded_ids, delta_ids)
|
||||
assert is_reasoning_end == actual
|
||||
|
||||
|
||||
def test_gptoss_is_reasoning_end_streaming_signature(gpt_oss_tokenizer):
|
||||
"""Verify the method is callable with the expected signature."""
|
||||
parser = GptOssReasoningParser(gpt_oss_tokenizer)
|
||||
result = parser.is_reasoning_end_streaming([], [])
|
||||
assert result is False
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -112,6 +112,25 @@ class GptOssReasoningParser(ReasoningParser):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: Sequence[int], delta_ids: Iterable[int]
|
||||
) -> bool:
|
||||
# The pattern window covers the end-of-reasoning marker itself.
|
||||
# We add len(delta_ids) so that under speculative decoding (where
|
||||
# a single step can accept many tokens) the entire accepted chunk
|
||||
# is always inside the scan region.
|
||||
delta_ids = tuple(delta_ids)
|
||||
pattern_len = (
|
||||
len(self.reasoning_end_token_ids_prefix)
|
||||
+ self.reasoning_max_num_between_tokens
|
||||
+ len(self.reasoning_end_token_ids_suffix)
|
||||
)
|
||||
window = pattern_len + len(delta_ids)
|
||||
n = len(input_ids)
|
||||
if n <= window:
|
||||
return self.is_reasoning_end(input_ids)
|
||||
return self.is_reasoning_end(input_ids[n - window :])
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
_, content, _ = parse_chat_output(input_ids)
|
||||
if content is None:
|
||||
|
||||
Reference in New Issue
Block a user