[Frontend] Fix reasoning_tokens for text-based parsers in Responses API (#33513)

Signed-off-by: Jaeyeon Kim <anencore94@gmail.com>
This commit is contained in:
Jaeyeon Kim(김재연)
2026-02-19 08:16:41 +01:00
committed by GitHub
parent b6101d384d
commit 9681068cf9
7 changed files with 208 additions and 3 deletions
@@ -134,6 +134,53 @@ async def test_streaming_output_consistency(client: OpenAI, model_name: str):
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_streaming_reasoning_tokens_e2e(client: OpenAI, model_name: str):
"""Verify final usage includes reasoning_tokens in streaming mode."""
response = await client.responses.create(
model=model_name,
input="Compute 17 * 19 and explain briefly.",
reasoning={"effort": "low"},
temperature=0.0,
stream=True,
)
completed_event = None
async for event in response:
if event.type == "response.completed":
completed_event = event
assert completed_event is not None
assert completed_event.response.status == "completed"
assert completed_event.response.usage is not None
assert completed_event.response.usage.output_tokens_details is not None
assert completed_event.response.usage.output_tokens_details.reasoning_tokens > 0, (
"Expected reasoning_tokens > 0 for streamed Qwen3 response."
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_non_streaming_reasoning_tokens_e2e(client: OpenAI, model_name: str):
"""Verify usage includes reasoning_tokens in non-streaming mode."""
response = await client.responses.create(
model=model_name,
input="Compute 23 * 17 and explain briefly.",
reasoning={"effort": "low"},
temperature=0.0,
stream=False,
)
assert response is not None
assert response.status == "completed"
assert response.usage is not None
assert response.usage.output_tokens_details is not None
assert response.usage.output_tokens_details.reasoning_tokens > 0, (
"Expected reasoning_tokens > 0 for non-streamed Qwen3 response."
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_max_tokens(client: OpenAI, model_name: str):
@@ -13,9 +13,13 @@ from openai.types.responses.tool import (
Tool,
)
import vllm.envs as envs
from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.responses.context import ConversationContext
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.responses.context import ConversationContext, SimpleContext
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.openai.responses.serving import (
OpenAIServingResponses,
@@ -23,6 +27,8 @@ from vllm.entrypoints.openai.responses.serving import (
extract_tool_types,
)
from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
class MockConversationContext(ConversationContext):
@@ -259,6 +265,87 @@ class TestValidateGeneratorInput:
assert isinstance(result, ErrorResponse)
@pytest.mark.asyncio
async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
"""Ensure reasoning_tokens usage is derived from thinking token spans."""
class FakeTokenizer:
def __init__(self):
self._vocab = {"<think>": 1, "</think>": 2, "reason": 3, "final": 4}
def get_vocab(self):
return self._vocab
# Force non-harmony, SimpleContext path
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
engine_client = MagicMock()
model_config = MagicMock()
model_config.hf_config.model_type = "test"
model_config.hf_text_config = MagicMock()
model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
tokenizer = FakeTokenizer()
engine_client.renderer.get_tokenizer.return_value = tokenizer
models = MagicMock()
serving = OpenAIServingResponses(
engine_client=engine_client,
models=models,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
reasoning_parser="qwen3",
)
# Build a SimpleContext with thinking tokens in the output.
context = SimpleContext()
token_ids = [1, 10, 2, 20] # <think> 10 </think> 20 -> reasoning token count = 1
completion = CompletionOutput(
index=0,
text="<think>reason</think>final",
token_ids=token_ids,
cumulative_logprob=0.0,
logprobs=None,
finish_reason="stop",
stop_reason=None,
)
req_output = RequestOutput(
request_id="req",
prompt="hi",
prompt_token_ids=[7, 8],
prompt_logprobs=None,
outputs=[completion],
finished=True,
num_cached_tokens=0,
)
context.append_output(req_output)
async def dummy_result_generator():
yield None
request = ResponsesRequest(input="hi", tools=[], stream=False)
sampling_params = SamplingParams(max_tokens=16)
metadata = RequestResponseMetadata(request_id="req")
response = await serving.responses_full_generator(
request=request,
sampling_params=sampling_params,
result_generator=dummy_result_generator(),
context=context,
model_name="test-model",
tokenizer=tokenizer,
request_metadata=metadata,
)
assert response.usage.output_tokens_details.reasoning_tokens == 1
class TestExtractAllowedToolsFromMcpRequests:
"""Test class for _extract_allowed_tools_from_mcp_requests function"""
@@ -167,6 +167,23 @@ class TestBaseThinkingReasoningParserMethods:
is False
)
def test_count_reasoning_tokens(self, test_tokenizer):
"""Count tokens between start/end markers."""
parser = TestThinkingReasoningParser(test_tokenizer)
start = parser.start_token_id
end = parser.end_token_id
token_ids = [0, start, 11, 12, end, 99]
assert parser.count_reasoning_tokens(token_ids) == 2
def test_count_reasoning_tokens_nested(self, test_tokenizer):
"""Ensure nested thinking spans count all inner tokens safely."""
parser = TestThinkingReasoningParser(test_tokenizer)
s = parser.start_token_id
e = parser.end_token_id
token_ids = [s, 1, s, 2, e, 3, e]
# Tokens 1,2,3 are inside reasoning (depth>0) => 3 tokens
assert parser.count_reasoning_tokens(token_ids) == 3
def test_extract_content_ids(self, test_tokenizer):
"""Test the extract_content_ids method."""
parser = TestThinkingReasoningParser(test_tokenizer)
+3 -1
View File
@@ -280,7 +280,6 @@ class ParsableContext(ConversationContext):
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
# TODO: num_reasoning_tokens is not implemented yet.
self.num_reasoning_tokens = 0
# not implemented yet for ParsableContext
self.all_turn_metrics: list[TurnMetrics] = []
@@ -308,12 +307,15 @@ class ParsableContext(ConversationContext):
self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = []
self._accumulated_token_ids: list[int] = []
def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
self.parser.process(output.outputs[0])
output_token_ids = output.outputs[0].token_ids or []
self._accumulated_token_ids.extend(output_token_ids)
# only store if enable_response_messages is True, save memory
if self.request.enable_response_messages:
@@ -759,6 +759,19 @@ class OpenAIServingResponses(OpenAIServing):
num_generated_tokens = context.num_output_tokens
num_cached_tokens = context.num_cached_tokens
num_reasoning_tokens = context.num_reasoning_tokens
# For text-based reasoning parsers (e.g., <think>...</think>),
# HarmonyContext already counts reasoning tokens via channels.
# For Simple/Parsable contexts, derive reasoning_tokens from
# accumulated output token IDs using the parser if not already set.
if (
num_reasoning_tokens == 0
and self.parser is not None
and self.parser.reasoning_parser_cls is not None
and isinstance(context, (SimpleContext, ParsableContext))
):
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
accumulated = getattr(context, "_accumulated_token_ids", []) or []
num_reasoning_tokens = reasoning_parser.count_reasoning_tokens(accumulated)
usage = ResponseUsage(
input_tokens=num_prompt_tokens,
+19
View File
@@ -104,6 +104,25 @@ class ReasoningParser:
The extracted content from the input_ids.
"""
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
"""Count the number of reasoning tokens in a sequence.
Text-based reasoning models typically wrap their chain-of-thought
between special start/end tokens (e.g., ``<think> ... </think>``).
Implementations that support reasoning token counting should override
this method. The default implementation returns ``0`` so existing
parsers remain unchanged unless they explicitly opt in.
Args:
token_ids: Sequence of generated token ids (excluding prompt).
Returns:
int: Number of tokens that belong to reasoning content.
"""
# By default, assume the parser cannot detect reasoning spans.
return 0
@abstractmethod
def extract_reasoning(
self,
+20
View File
@@ -175,3 +175,23 @@ class BaseThinkingReasoningParser(ReasoningParser):
# If generation stops right after end-of-think, return null content
final_content = content or None
return reasoning, final_content
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
"""Count tokens that fall within start/end thinking markers.
Uses a depth counter so nested spans are handled safely and stray end
tokens do not drive the counter negative.
"""
count = 0
depth = 0
for token_id in token_ids:
if token_id == self.start_token_id:
depth += 1
continue
if token_id == self.end_token_id:
if depth > 0:
depth -= 1
continue
if depth > 0:
count += 1
return count