Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Merge 59261df527 into 6df2c8a074
This commit is contained in:
commit 2453a1d8e3
@@ -624,6 +624,9 @@ class DetokenizedGenerationResultBase(GenerationResultBase):
        self._streaming = streaming

    def _handle_response(self, response: "GenerationExecutor.Response"):
        # Save token lengths before processing to detect which outputs received new tokens
        prev_token_lens = {id(o): len(o.token_ids) for o in self._outputs}

        GenerationResultBase._handle_response(self, response)

        # The postprocess has been performed, return directly
@@ -638,7 +641,15 @@ class DetokenizedGenerationResultBase(GenerationResultBase):
        }
        if self.sampling_params.detokenize and self.tokenizer is not None:
            for beam_output in self.outputs:
                # Always update _last_text_len to prevent stale text_diff
                beam_output._last_text_len = len(beam_output.text)
                # For n > 1 streaming: only detokenize outputs that received new tokens
                # to prevent re-decoding the same tokens multiple times
                output_received_new_tokens = len(
                    beam_output.token_ids) != prev_token_lens.get(
                        id(beam_output), 0)
                if not output_received_new_tokens:
                    continue
                if hasattr(
                        self.tokenizer, 'decode_incrementally'
                ) and self._streaming and not self.sampling_params.use_beam_search:
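For context, here is a minimal standalone sketch of the pattern the two hunks above implement: snapshot each output's token count before a response is processed, then skip re-decoding any output whose count did not grow. The Output class, decode helper, and handle_response driver below are hypothetical stand-ins for illustration, not TensorRT-LLM APIs.

    # Sketch of "only re-decode outputs that received new tokens".
    # Output, decode, and handle_response are hypothetical stand-ins.

    class Output:
        def __init__(self):
            self.token_ids = []
            self.text = ""

    def decode(token_ids):
        # Placeholder detokenizer: render token ids as text.
        return " ".join(str(t) for t in token_ids)

    def handle_response(outputs, new_tokens_per_output):
        # Snapshot token counts before processing, mirroring prev_token_lens above.
        prev_token_lens = {id(o): len(o.token_ids) for o in outputs}

        # Apply the new tokens (stand-in for GenerationResultBase._handle_response).
        for out, new_tokens in zip(outputs, new_tokens_per_output):
            out.token_ids.extend(new_tokens)

        # Only re-decode outputs whose token count actually changed.
        for out in outputs:
            if len(out.token_ids) == prev_token_lens.get(id(out), 0):
                continue  # this output received no new tokens in this response
            out.text = decode(out.token_ids)

    outputs = [Output(), Output(), Output()]
    handle_response(outputs, [[1, 2], [], [3]])  # the second output receives nothing
    print([o.text for o in outputs])             # ['1 2', '', '3']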
@@ -5,6 +5,7 @@ from typing import List

import openai
import pytest
from utils.util import similar

from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
@@ -204,6 +205,39 @@ async def test_batch_completions_streaming(async_client: openai.AsyncOpenAI,
    assert texts[0] == texts[1]


@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize("prompts", [["Hello, my name is"] * 2])
async def test_batch_completions_with_option_n_streaming(
        async_client: openai.AsyncOpenAI, model_name, prompts):
    # Use non-stream single generation as reference
    completion_ref = await async_client.completions.create(
        model=model_name,
        prompt=prompts[0],
        max_tokens=5,
        temperature=0.0001,
    )
    text_ref = completion_ref.choices[0].text

    # test n>1 with streaming
    batch = await async_client.completions.create(
        model=model_name,
        prompt=prompts,
        n=3,  # number of completions to generate for each prompt
        max_tokens=5,
        temperature=0.0001,
        stream=True,
    )
    texts = [""] * 6  # 2 prompts × 3 generations per prompt = 6 choices
    async for chunk in batch:
        assert len(chunk.choices) == 1
        choice = chunk.choices[0]
        texts[choice.index] += choice.text

    # Check all generations are consistent with the reference
    for text in texts:
        assert similar(text, text_ref, threshold=0.8)


@pytest.mark.asyncio(loop_scope="module")
async def test_completion_stream_options(async_client: openai.AsyncOpenAI,
                                         model_name: str):
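For reference, a sketch of how a client might group the streamed texts back into per-prompt, per-generation buckets. The prompt-major choice.index layout assumed by the divmod call is an assumption for illustration only; the test above relies solely on each index accumulating into its own slot.

    # Assumption for illustration: choices are indexed prompt-major, i.e. the
    # first n indices belong to prompt 0, the next n to prompt 1, and so on.
    async def collect_grouped(stream, num_prompts: int, n: int):
        grouped = [["" for _ in range(n)] for _ in range(num_prompts)]
        async for chunk in stream:
            choice = chunk.choices[0]
            prompt_idx, gen_idx = divmod(choice.index, n)
            grouped[prompt_idx][gen_idx] += choice.text
        return grouped

    # Hypothetical usage: grouped = await collect_grouped(batch, num_prompts=2, n=3)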