diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py
index 1ba163581e..f9e168a394 100644
--- a/tensorrt_llm/serve/harmony_adapter.py
+++ b/tensorrt_llm/serve/harmony_adapter.py
@@ -1521,10 +1521,10 @@ class HarmonyAdapter:
         return True
 
 
-_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
+_SERVE_HARMONY_ADAPTER: HarmonyAdapter | None = None
 
 
-def get_harmony_adapter():
+def get_harmony_adapter() -> HarmonyAdapter:
     global _SERVE_HARMONY_ADAPTER
     if _SERVE_HARMONY_ADAPTER is None:
         _SERVE_HARMONY_ADAPTER = HarmonyAdapter()
@@ -1535,8 +1535,8 @@ def get_harmony_adapter():
 def handle_streaming_response(tools: List[ChatCompletionToolsParam],
                               tool_choice: str, result: GenerationResult,
                               model: str, request_id: str, done: bool,
-                              num_prompt_tokens: int) -> List[str]:
-    first_iteration = True
+                              num_prompt_tokens: int,
+                              first_iteration: bool) -> List[str]:
     output = result.outputs[0]
 
     # Convert tools to dictionary format for harmony adapter (standard pattern)
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 524c7440ab..9d3f60fcb0 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -873,11 +873,12 @@ class OpenAIServer:
             return chat_response
 
         async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
-            if not self.postproc_worker_enabled:
-                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
-
             async for res in promise:
-                pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
+                if not self.postproc_worker_enabled:
+                    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                    pp_results = post_processor(res, args)
+                else:
+                    pp_results = res.outputs[0]._postprocess_result
                 for pp_res in pp_results:
                     yield pp_res
 
diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py
index 38c6c93563..bacce813b6 100644
--- a/tensorrt_llm/serve/postprocess_handlers.py
+++ b/tensorrt_llm/serve/postprocess_handlers.py
@@ -593,7 +593,9 @@ def chat_harmony_streaming_post_processor(
         request_id=args.request_id,
         done=rsp._done,
         num_prompt_tokens=args.num_prompt_tokens,
+        first_iteration=args.first_iteration,
     )
+    args.first_iteration = False
 
     return response
 
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
index fc550f0824..5eadaf88ae 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
@@ -161,15 +161,29 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
         }],
         stream=True,
     )
-    collected_chunks = []
     collected_messages = []
+    first_iteration = True
     async for chunk in response:
-        # Last streaming response will only contains usage info
-        if len(chunk.choices) <= 0:
-            continue
-
-        collected_chunks.append(chunk)
-        collected_messages.append(chunk.choices[0].delta)
+        if chunk.choices:
+            delta = chunk.choices[0].delta
+            if first_iteration:
+                assert delta.role == "assistant", ValueError(
+                    "Expected role 'assistant' for first iteration")
+                collected_messages.append(delta)
+                first_iteration = False
+            else:
+                assert delta.role is None, ValueError(
+                    "Expected no role except for first iteration")
+                collected_messages.append(delta)
+        else:
+            assert hasattr(chunk, "usage"), ValueError(
+                "Expected usage info in last streaming response")
+            assert chunk.usage.prompt_tokens is not None
+            assert chunk.usage.completion_tokens is not None
+            assert chunk.usage.total_tokens is not None
+            assert chunk.usage.prompt_tokens > 0
+            assert chunk.usage.completion_tokens > 0
+            assert chunk.usage.total_tokens > 0
 
     full_response = "".join([
         m.content for m in collected_messages
diff --git a/tests/unittest/llmapi/apps/test_harmony_channel_validation.py b/tests/unittest/llmapi/apps/test_harmony_channel_validation.py
index 6e6046a3ac..bb16cb4bc9 100644
--- a/tests/unittest/llmapi/apps/test_harmony_channel_validation.py
+++ b/tests/unittest/llmapi/apps/test_harmony_channel_validation.py
@@ -362,6 +362,7 @@ class TestHandleStreamingResponse:
             request_id=request_id,
             done=False,  # Not done yet, still streaming
             num_prompt_tokens=10,
+            first_iteration=True,
         )
 
         # CRITICAL ASSERTION: result.abort() should be called
@@ -424,6 +425,7 @@ class TestHandleStreamingResponse:
             request_id=request_id,
             done=False,
             num_prompt_tokens=10,
+            first_iteration=True,
         )
 
         # CRITICAL ASSERTION: result.abort() should NOT be called
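
Taken together, these hunks move the first-chunk state out of `handle_streaming_response` (where the local `first_iteration = True` was reset on every streaming step, so the `assistant` role could be re-emitted in every delta) onto the per-request postprocess args, which persist across calls. Below is a minimal, self-contained sketch of that pattern under stated assumptions: `PostprocArgs`, the handler, and the driver loop here are hypothetical stand-ins, not the actual TensorRT-LLM classes.

```python
from dataclasses import dataclass


@dataclass
class PostprocArgs:
    """Hypothetical stand-in for the real per-request postproc args object."""
    request_id: str
    first_iteration: bool = True  # persists across streaming steps


def handle_streaming_response(text: str, first_iteration: bool) -> dict:
    """Hypothetical stand-in: builds one OpenAI-style streaming delta."""
    delta = {"content": text}
    if first_iteration:
        # The OpenAI streaming API sends the role only in the first chunk.
        delta["role"] = "assistant"
    return delta


def streaming_post_processor(text: str, args: PostprocArgs) -> dict:
    # Same shape as the chat_harmony_streaming_post_processor change:
    # read the flag from the stateful args, then clear it so subsequent
    # chunks for this request omit the role.
    response = handle_streaming_response(text,
                                         first_iteration=args.first_iteration)
    args.first_iteration = False
    return response


if __name__ == "__main__":
    args = PostprocArgs(request_id="req-0")
    chunks = [streaming_post_processor(t, args) for t in ["Hel", "lo", "!"]]
    assert chunks[0].get("role") == "assistant"
    assert all("role" not in c for c in chunks[1:])
```

Because the flag lives on the args object rather than in a local, each request gets exactly one role-bearing chunk regardless of how many times the post-processor is invoked, which is what the updated `test_streaming` assertions check.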