mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Frontend] Fix reasoning_tokens for text-based parsers in Responses API (#33513)
Signed-off-by: Jaeyeon Kim <anencore94@gmail.com>
This commit is contained in:
@@ -134,6 +134,53 @@ async def test_streaming_output_consistency(client: OpenAI, model_name: str):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_streaming_reasoning_tokens_e2e(client: OpenAI, model_name: str):
|
||||
"""Verify final usage includes reasoning_tokens in streaming mode."""
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="Compute 17 * 19 and explain briefly.",
|
||||
reasoning={"effort": "low"},
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
completed_event = None
|
||||
async for event in response:
|
||||
if event.type == "response.completed":
|
||||
completed_event = event
|
||||
|
||||
assert completed_event is not None
|
||||
assert completed_event.response.status == "completed"
|
||||
assert completed_event.response.usage is not None
|
||||
assert completed_event.response.usage.output_tokens_details is not None
|
||||
assert completed_event.response.usage.output_tokens_details.reasoning_tokens > 0, (
|
||||
"Expected reasoning_tokens > 0 for streamed Qwen3 response."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_non_streaming_reasoning_tokens_e2e(client: OpenAI, model_name: str):
|
||||
"""Verify usage includes reasoning_tokens in non-streaming mode."""
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="Compute 23 * 17 and explain briefly.",
|
||||
reasoning={"effort": "low"},
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
assert response.usage is not None
|
||||
assert response.usage.output_tokens_details is not None
|
||||
assert response.usage.output_tokens_details.reasoning_tokens > 0, (
|
||||
"Expected reasoning_tokens > 0 for non-streamed Qwen3 response."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_max_tokens(client: OpenAI, model_name: str):
|
||||
|
||||
@@ -13,9 +13,13 @@ from openai.types.responses.tool import (
|
||||
Tool,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.mcp.tool_server import ToolServer
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.responses.context import ConversationContext
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.context import ConversationContext, SimpleContext
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.entrypoints.openai.responses.serving import (
|
||||
OpenAIServingResponses,
|
||||
@@ -23,6 +27,8 @@ from vllm.entrypoints.openai.responses.serving import (
|
||||
extract_tool_types,
|
||||
)
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class MockConversationContext(ConversationContext):
|
||||
@@ -259,6 +265,87 @@ class TestValidateGeneratorInput:
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
|
||||
"""Ensure reasoning_tokens usage is derived from thinking token spans."""
|
||||
|
||||
class FakeTokenizer:
|
||||
def __init__(self):
|
||||
self._vocab = {"<think>": 1, "</think>": 2, "reason": 3, "final": 4}
|
||||
|
||||
def get_vocab(self):
|
||||
return self._vocab
|
||||
|
||||
# Force non-harmony, SimpleContext path
|
||||
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
|
||||
|
||||
engine_client = MagicMock()
|
||||
model_config = MagicMock()
|
||||
model_config.hf_config.model_type = "test"
|
||||
model_config.hf_text_config = MagicMock()
|
||||
model_config.get_diff_sampling_param.return_value = {}
|
||||
engine_client.model_config = model_config
|
||||
engine_client.input_processor = MagicMock()
|
||||
engine_client.io_processor = MagicMock()
|
||||
engine_client.renderer = MagicMock()
|
||||
|
||||
tokenizer = FakeTokenizer()
|
||||
engine_client.renderer.get_tokenizer.return_value = tokenizer
|
||||
|
||||
models = MagicMock()
|
||||
|
||||
serving = OpenAIServingResponses(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=None,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
reasoning_parser="qwen3",
|
||||
)
|
||||
|
||||
# Build a SimpleContext with thinking tokens in the output.
|
||||
context = SimpleContext()
|
||||
token_ids = [1, 10, 2, 20] # <think> 10 </think> 20 -> reasoning token count = 1
|
||||
completion = CompletionOutput(
|
||||
index=0,
|
||||
text="<think>reason</think>final",
|
||||
token_ids=token_ids,
|
||||
cumulative_logprob=0.0,
|
||||
logprobs=None,
|
||||
finish_reason="stop",
|
||||
stop_reason=None,
|
||||
)
|
||||
req_output = RequestOutput(
|
||||
request_id="req",
|
||||
prompt="hi",
|
||||
prompt_token_ids=[7, 8],
|
||||
prompt_logprobs=None,
|
||||
outputs=[completion],
|
||||
finished=True,
|
||||
num_cached_tokens=0,
|
||||
)
|
||||
context.append_output(req_output)
|
||||
|
||||
async def dummy_result_generator():
|
||||
yield None
|
||||
|
||||
request = ResponsesRequest(input="hi", tools=[], stream=False)
|
||||
sampling_params = SamplingParams(max_tokens=16)
|
||||
metadata = RequestResponseMetadata(request_id="req")
|
||||
|
||||
response = await serving.responses_full_generator(
|
||||
request=request,
|
||||
sampling_params=sampling_params,
|
||||
result_generator=dummy_result_generator(),
|
||||
context=context,
|
||||
model_name="test-model",
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=metadata,
|
||||
)
|
||||
|
||||
assert response.usage.output_tokens_details.reasoning_tokens == 1
|
||||
|
||||
|
||||
class TestExtractAllowedToolsFromMcpRequests:
|
||||
"""Test class for _extract_allowed_tools_from_mcp_requests function"""
|
||||
|
||||
|
||||
@@ -167,6 +167,23 @@ class TestBaseThinkingReasoningParserMethods:
|
||||
is False
|
||||
)
|
||||
|
||||
def test_count_reasoning_tokens(self, test_tokenizer):
|
||||
"""Count tokens between start/end markers."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
start = parser.start_token_id
|
||||
end = parser.end_token_id
|
||||
token_ids = [0, start, 11, 12, end, 99]
|
||||
assert parser.count_reasoning_tokens(token_ids) == 2
|
||||
|
||||
def test_count_reasoning_tokens_nested(self, test_tokenizer):
|
||||
"""Ensure nested thinking spans count all inner tokens safely."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
s = parser.start_token_id
|
||||
e = parser.end_token_id
|
||||
token_ids = [s, 1, s, 2, e, 3, e]
|
||||
# Tokens 1,2,3 are inside reasoning (depth>0) => 3 tokens
|
||||
assert parser.count_reasoning_tokens(token_ids) == 3
|
||||
|
||||
def test_extract_content_ids(self, test_tokenizer):
|
||||
"""Test the extract_content_ids method."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
|
||||
@@ -280,7 +280,6 @@ class ParsableContext(ConversationContext):
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
# TODO: num_reasoning_tokens is not implemented yet.
|
||||
self.num_reasoning_tokens = 0
|
||||
# not implemented yet for ParsableContext
|
||||
self.all_turn_metrics: list[TurnMetrics] = []
|
||||
@@ -308,12 +307,15 @@ class ParsableContext(ConversationContext):
|
||||
|
||||
self.input_messages: list[ResponseRawMessageAndToken] = []
|
||||
self.output_messages: list[ResponseRawMessageAndToken] = []
|
||||
self._accumulated_token_ids: list[int] = []
|
||||
|
||||
def append_output(self, output: RequestOutput) -> None:
|
||||
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
self.parser.process(output.outputs[0])
|
||||
output_token_ids = output.outputs[0].token_ids or []
|
||||
self._accumulated_token_ids.extend(output_token_ids)
|
||||
|
||||
# only store if enable_response_messages is True, save memory
|
||||
if self.request.enable_response_messages:
|
||||
|
||||
@@ -759,6 +759,19 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
num_generated_tokens = context.num_output_tokens
|
||||
num_cached_tokens = context.num_cached_tokens
|
||||
num_reasoning_tokens = context.num_reasoning_tokens
|
||||
# For text-based reasoning parsers (e.g., <think>...</think>),
|
||||
# HarmonyContext already counts reasoning tokens via channels.
|
||||
# For Simple/Parsable contexts, derive reasoning_tokens from
|
||||
# accumulated output token IDs using the parser if not already set.
|
||||
if (
|
||||
num_reasoning_tokens == 0
|
||||
and self.parser is not None
|
||||
and self.parser.reasoning_parser_cls is not None
|
||||
and isinstance(context, (SimpleContext, ParsableContext))
|
||||
):
|
||||
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
|
||||
accumulated = getattr(context, "_accumulated_token_ids", []) or []
|
||||
num_reasoning_tokens = reasoning_parser.count_reasoning_tokens(accumulated)
|
||||
|
||||
usage = ResponseUsage(
|
||||
input_tokens=num_prompt_tokens,
|
||||
|
||||
@@ -104,6 +104,25 @@ class ReasoningParser:
|
||||
The extracted content from the input_ids.
|
||||
"""
|
||||
|
||||
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
|
||||
"""Count the number of reasoning tokens in a sequence.
|
||||
|
||||
Text-based reasoning models typically wrap their chain-of-thought
|
||||
between special start/end tokens (e.g., ``<think> ... </think>``).
|
||||
Implementations that support reasoning token counting should override
|
||||
this method. The default implementation returns ``0`` so existing
|
||||
parsers remain unchanged unless they explicitly opt in.
|
||||
|
||||
Args:
|
||||
token_ids: Sequence of generated token ids (excluding prompt).
|
||||
|
||||
Returns:
|
||||
int: Number of tokens that belong to reasoning content.
|
||||
"""
|
||||
|
||||
# By default, assume the parser cannot detect reasoning spans.
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def extract_reasoning(
|
||||
self,
|
||||
|
||||
@@ -175,3 +175,23 @@ class BaseThinkingReasoningParser(ReasoningParser):
|
||||
# If generation stops right after end-of-think, return null content
|
||||
final_content = content or None
|
||||
return reasoning, final_content
|
||||
|
||||
def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
|
||||
"""Count tokens that fall within start/end thinking markers.
|
||||
|
||||
Uses a depth counter so nested spans are handled safely and stray end
|
||||
tokens do not drive the counter negative.
|
||||
"""
|
||||
count = 0
|
||||
depth = 0
|
||||
for token_id in token_ids:
|
||||
if token_id == self.start_token_id:
|
||||
depth += 1
|
||||
continue
|
||||
if token_id == self.end_token_id:
|
||||
if depth > 0:
|
||||
depth -= 1
|
||||
continue
|
||||
if depth > 0:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
Reference in New Issue
Block a user