Simeng Liu 2026-01-13 21:25:08 +08:00 committed by GitHub
commit 2453a1d8e3
2 changed files with 45 additions and 0 deletions


@@ -624,6 +624,9 @@ class DetokenizedGenerationResultBase(GenerationResultBase):
        self._streaming = streaming

    def _handle_response(self, response: "GenerationExecutor.Response"):
        # Save token lengths before processing to detect which outputs received new tokens
        prev_token_lens = {id(o): len(o.token_ids) for o in self._outputs}

        GenerationResultBase._handle_response(self, response)

        # The postprocess has been performed, return directly
@@ -638,7 +641,15 @@ class DetokenizedGenerationResultBase(GenerationResultBase):
        }
        if self.sampling_params.detokenize and self.tokenizer is not None:
            for beam_output in self.outputs:
                # Always update _last_text_len to prevent stale text_diff
                beam_output._last_text_len = len(beam_output.text)
                # For n > 1 streaming: only detokenize outputs that received new tokens
                # to prevent re-decoding the same tokens multiple times
                output_received_new_tokens = len(
                    beam_output.token_ids) != prev_token_lens.get(
                        id(beam_output), 0)
                if not output_received_new_tokens:
                    continue
                if hasattr(
                        self.tokenizer, 'decode_incrementally'
                ) and self._streaming and not self.sampling_params.use_beam_search:
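The hunk above is the core of the fix: with n > 1 and streaming enabled, a single response may update only some of a request's outputs, yet the detokenization loop previously walked all of them and re-decoded tokens that had not changed. The sketch below is a minimal, self-contained illustration of the same pattern; the Output class and fake_decode function are hypothetical stand-ins for this example, not the TensorRT-LLM API.

# Minimal sketch of the "only detokenize outputs that received new tokens" pattern.
# `Output` and `fake_decode` are hypothetical stand-ins for illustration only.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Output:
    token_ids: List[int] = field(default_factory=list)
    text: str = ""


def fake_decode(token_ids: List[int]) -> str:
    # Stand-in for (incremental) detokenization.
    return "".join(f"<{t}>" for t in token_ids)


def handle_response(outputs: List[Output],
                    new_tokens_per_output: List[List[int]]) -> None:
    # Snapshot token lengths before the response is applied.
    prev_token_lens = {id(o): len(o.token_ids) for o in outputs}

    # Apply the response: some outputs may receive no new tokens this step.
    for out, new_tokens in zip(outputs, new_tokens_per_output):
        out.token_ids.extend(new_tokens)

    # Detokenize only the outputs whose token count actually grew,
    # so unchanged outputs are not re-decoded on every streamed chunk.
    for out in outputs:
        if len(out.token_ids) == prev_token_lens.get(id(out), 0):
            continue
        out.text = fake_decode(out.token_ids)


if __name__ == "__main__":
    outs = [Output(), Output(), Output()]
    handle_response(outs, [[1, 2], [], [3]])  # second output gets nothing
    print([o.text for o in outs])             # ['<1><2>', '', '<3>']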


@@ -5,6 +5,7 @@ from typing import List
import openai
import pytest
from utils.util import similar
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
@@ -204,6 +205,39 @@ async def test_batch_completions_streaming(async_client: openai.AsyncOpenAI,
    assert texts[0] == texts[1]


@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("prompts", [["Hello, my name is"] * 2])
async def test_batch_completions_with_option_n_streaming(
        async_client: openai.AsyncOpenAI, model_name, prompts):
    # Use non-stream single generation as reference
    completion_ref = await async_client.completions.create(
        model=model_name,
        prompt=prompts[0],
        max_tokens=5,
        temperature=0.0001,
    )
    text_ref = completion_ref.choices[0].text

    # test n>1 with streaming
    batch = await async_client.completions.create(
        model=model_name,
        prompt=prompts,
        n=3,  # number of completions to generate for each prompt
        max_tokens=5,
        temperature=0.0001,
        stream=True,
    )
    texts = [""] * 6  # 2 prompts × 3 generations per prompt = 6 choices
    async for chunk in batch:
        assert len(chunk.choices) == 1
        choice = chunk.choices[0]
        texts[choice.index] += choice.text

    # Check all generations are consistent with the reference
    for text in texts:
        assert similar(text, text_ref, threshold=0.8)
@pytest.mark.asyncio(loop_scope="module")
async def test_completion_stream_options(async_client: openai.AsyncOpenAI,
                                         model_name: str):
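The new test compares each of the six streamed completions against a single non-streaming reference using the in-repo helper utils.util.similar imported at the top of the file. That helper's implementation is not shown in this diff; a plausible stand-in, for readers who want to reproduce the fuzzy check locally, is a difflib ratio comparison like the one below (an assumption, not the project's actual code).

# Hypothetical stand-in for utils.util.similar; the real helper is not shown in this diff.
from difflib import SequenceMatcher


def similar(a: str, b: str, threshold: float = 0.8) -> bool:
    # True when the two strings overlap by at least `threshold`,
    # as measured by difflib's ratio (1.0 means identical strings).
    return SequenceMatcher(None, a, b).ratio() >= threshold


assert similar("Hello, my name is John", "Hello, my name is Jane", threshold=0.8)
assert not similar("Hello, my name is John", "Goodbye", threshold=0.8)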