[#10614][fix] gpt_oss first iteration streaming in trtllm-serve (#10808)

Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Pengyun Lin 2026-01-26 20:53:11 +08:00 committed by GitHub
parent 5d7a5e6800
commit ce37e27066
5 changed files with 34 additions and 15 deletions

View File

@@ -1521,10 +1521,10 @@ class HarmonyAdapter:
         return True


-_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
+_SERVE_HARMONY_ADAPTER: HarmonyAdapter | None = None


-def get_harmony_adapter():
+def get_harmony_adapter() -> HarmonyAdapter:
     global _SERVE_HARMONY_ADAPTER
     if _SERVE_HARMONY_ADAPTER is None:
         _SERVE_HARMONY_ADAPTER = HarmonyAdapter()
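
Note: the annotation change above only documents the existing lazy-initialization behaviour; the module-level singleton stays None until the first call, so `HarmonyAdapter | None` is the accurate type and the return annotation promises callers a constructed instance. A minimal sketch of the same pattern (names here are illustrative, not the actual module):

    from typing import Optional

    class HarmonyAdapter:
        pass  # placeholder for the real adapter

    _ADAPTER: Optional[HarmonyAdapter] = None  # None until first use, hence Optional

    def get_adapter() -> HarmonyAdapter:
        global _ADAPTER
        if _ADAPTER is None:           # construct lazily on first call
            _ADAPTER = HarmonyAdapter()
        return _ADAPTER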
@@ -1535,8 +1535,8 @@ def get_harmony_adapter():
 def handle_streaming_response(tools: List[ChatCompletionToolsParam],
                               tool_choice: str, result: GenerationResult,
                               model: str, request_id: str, done: bool,
-                              num_prompt_tokens: int) -> List[str]:
-    first_iteration = True
+                              num_prompt_tokens: int,
+                              first_iteration: bool) -> List[str]:
     output = result.outputs[0]

     # Convert tools to dictionary format for harmony adapter (standard pattern)
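
Before this change, first_iteration was a local reset to True on every call, so each streamed batch was treated as the start of the response; it is now a parameter owned by the caller, which passes the current value and clears it after the first batch. A hypothetical driving loop (simplified, not the actual serve code) threads the flag like this:

    first_iteration = True
    async for result in generation_results:       # one GenerationResult per step (hypothetical source)
        for chunk in handle_streaming_response(
                tools, tool_choice, result, model, request_id,
                done=result.finished, num_prompt_tokens=num_prompt_tokens,
                first_iteration=first_iteration):
            yield chunk
        first_iteration = False                    # only the first batch carries role="assistant"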

View File

@@ -873,11 +873,12 @@ class OpenAIServer:
             return chat_response

         async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
-            if not self.postproc_worker_enabled:
-                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
             async for res in promise:
-                pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
+                if not self.postproc_worker_enabled:
+                    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                    pp_results = post_processor(res, args)
+                else:
+                    pp_results = res.outputs[0]._postprocess_result
                 for pp_res in pp_results:
                     yield pp_res
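
When postprocess workers are enabled, each RequestOutput already carries its post-processed chunks; otherwise the server now applies the post-processor in-process inside the loop. In the in-process case the same postproc_args object is reused for every result, which is what lets the first_iteration state added in the post-processor hunk below survive across iterations. Roughly, that path reduces to (a sketch, assuming workers are disabled):

    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
    async for res in promise:
        for pp_res in post_processor(res, args):   # args is bound once per request,
            yield pp_res                            # so state stored on it persists across chunks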

View File

@@ -593,7 +593,9 @@ def chat_harmony_streaming_post_processor(
         request_id=args.request_id,
         done=rsp._done,
         num_prompt_tokens=args.num_prompt_tokens,
+        first_iteration=args.first_iteration,
     )
+    args.first_iteration = False
     return response
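
The args object lives for the whole request, so clearing first_iteration after the call means only the first streamed batch is rendered with the assistant role. Roughly, the relevant state looks like this (a sketch; the class name is assumed and the real args carry more fields):

    from dataclasses import dataclass

    @dataclass
    class ChatPostprocArgs:            # assumed name for the per-request postproc args
        request_id: str
        num_prompt_tokens: int
        first_iteration: bool = True   # consumed by handle_streaming_response,
                                       # cleared after the first streamed batch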

View File

@@ -161,15 +161,29 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
         }],
         stream=True,
     )
-    collected_chunks = []
     collected_messages = []
+    first_iteration = True
     async for chunk in response:
-        # Last streaming response will only contains usage info
-        if len(chunk.choices) <= 0:
-            continue
-        collected_chunks.append(chunk)
-        collected_messages.append(chunk.choices[0].delta)
+        if chunk.choices:
+            delta = chunk.choices[0].delta
+            if first_iteration:
+                assert delta.role == "assistant", ValueError(
+                    "Expected role 'assistant' for first iteration")
+                collected_messages.append(delta)
+                first_iteration = False
+            else:
+                assert delta.role is None, ValueError(
+                    "Expected no role except for first iteration")
+                collected_messages.append(delta)
+        else:
+            assert hasattr(chunk, "usage"), ValueError(
+                "Expected usage info in last streaming response")
+            assert chunk.usage.prompt_tokens is not None
+            assert chunk.usage.completion_tokens is not None
+            assert chunk.usage.total_tokens is not None
+            assert chunk.usage.prompt_tokens > 0
+            assert chunk.usage.completion_tokens > 0
+            assert chunk.usage.total_tokens > 0
     full_response = "".join([
         m.content for m in collected_messages
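
The assertions encode the streaming contract this fix restores, following the OpenAI chunk format: the first delta carries role "assistant", later deltas carry only content, and the final chunk has no choices and reports usage. An abridged, illustrative chunk sequence (values made up):

    {"choices": [{"delta": {"role": "assistant", "content": ""}}]}      # first chunk: role set exactly once
    {"choices": [{"delta": {"content": "Hello"}}]}                      # subsequent chunks: content only
    {"choices": [], "usage": {"prompt_tokens": 9, "completion_tokens": 5, "total_tokens": 14}}  # final usage-only chunk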

View File

@@ -362,6 +362,7 @@ class TestHandleStreamingResponse:
             request_id=request_id,
             done=False,  # Not done yet, still streaming
             num_prompt_tokens=10,
+            first_iteration=True,
         )

         # CRITICAL ASSERTION: result.abort() should be called
@@ -424,6 +425,7 @@ class TestHandleStreamingResponse:
             request_id=request_id,
             done=False,
             num_prompt_tokens=10,
+            first_iteration=True,
         )

         # CRITICAL ASSERTION: result.abort() should NOT be called