diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py
index 2949965d72..9299b2b68a 100644
--- a/tensorrt_llm/serve/harmony_adapter.py
+++ b/tensorrt_llm/serve/harmony_adapter.py
@@ -6,7 +6,7 @@ import re
 import time
 import traceback
 import uuid
-from typing import Any, List, Literal
+from typing import Any, List, Literal, Tuple
 
 from openai_harmony import (Author, Conversation, DeveloperContent,
                             HarmonyEncodingName, HarmonyError, Message,
@@ -14,6 +14,7 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
                             SystemContent, TextContent, ToolDescription,
                             load_harmony_encoding)
 
+from tensorrt_llm.executor import GenerationResult
 from tensorrt_llm.logger import logger
 
 # yapf: disable
@@ -29,6 +30,19 @@ from .openai_protocol import (ChatCompletionMessageParam,
 # yapf: enable
 
 
+def _check_channel_valid(generated_channels: List[str], channel: str) -> bool:
+
+    if len(generated_channels) == 0 or generated_channels[-1] != channel:
+        generated_channels.append(channel)
+
+    logger.debug(f"generated_channels: {generated_channels}")
+    if "analysis" in generated_channels and "final" in generated_channels and len(
+            generated_channels) > 2:
+        return False
+
+    return True
+
+
 class HarmonyStreamState:
     """
     Maintains harmony parsing state for a single request across multiple token batches.
@@ -72,12 +86,14 @@ class HarmonyStreamState:
         # Track channel states for token preservation
         self.has_preamble_content = False
         self.current_channel_state = None  # "analysis", "commentary_preamble", "commentary_tool", "final"
+        self.generated_channels = [
+        ]  # Track generated channels to avoid generating too many messages
         self.channel_started = False  # Track if we've sent opening token for current channel
 
         # Track sent arguments for tool call streaming deltas
         self.sent_tool_arguments = {}  # tool_call_id -> sent_arguments_length
 
-        logger.debug("Created HarmonyStreamState for request %s", request_id)
+        logger.debug(f"Created HarmonyStreamState for request {request_id}")
 
     def get_parser(self) -> StreamableParser:
         return self.parser
@@ -182,6 +198,10 @@ class HarmonyStreamState:
         if not self.parser.last_content_delta:
             return None
 
+        if not _check_channel_valid(self.generated_channels,
+                                    self.parser.current_channel):
+            return {"should_stop": "Repeated message"}
+
         if self.parser.current_channel == "analysis":
             # Analysis channel -> reasoning (no token wrapping needed)
             self.current_channel_state = "analysis"
@@ -297,6 +317,8 @@ class HarmonyStreamState:
             self.parser.last_content_delta,
             "current_channel_state":
             self.current_channel_state,
+            "generated_channels":
+            self.generated_channels,
             "channel_started":
             self.channel_started,
             "has_preamble_content":
@@ -1014,12 +1036,16 @@ class HarmonyAdapter:
         commentary_preambles = []
         tool_calls = []
         final_content = ""
+        generated_channels = []
 
         for msg in harmony_messages:
             msg_channel = getattr(msg, 'channel', None)
             msg_recipient = getattr(msg, 'recipient', None)
             msg_content = getattr(msg, 'content', [])
 
+            if not _check_channel_valid(generated_channels, msg_channel):
+                continue
+
             if msg_channel == "analysis":
                 for content in msg_content:
                     if isinstance(content, TextContent):
@@ -1299,7 +1325,7 @@ class HarmonyAdapter:
             tokens: list[int],
             available_tools: list[dict[str, Any]] | None = None,
             model_name: str = "harmony-model",
-            tool_choice: str | None = None) -> list[str]:
+            tool_choice: str | None = None) -> Tuple[list[str], bool]:
         """
         Create properly formatted OpenAI streaming responses from harmony tokens.
 
@@ -1397,12 +1423,15 @@ class HarmonyAdapter:
                 delta_message.reasoning_content = None
                 # tool_calls will use default factory (empty list)
 
+            should_stop = ("should_stop" in harmony_delta)
+
             # Create the streaming response
-            choice = ChatCompletionResponseStreamChoice(index=0,
-                                                        delta=delta_message,
-                                                        logprobs=None,
-                                                        finish_reason=None,
-                                                        stop_reason=None)
+            choice = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=delta_message,
+                logprobs=None,
+                finish_reason="stop" if should_stop else None,
+                stop_reason=None)
 
             stream_response = ChatCompletionStreamResponse(model=model_name,
                                                            choices=[choice],
@@ -1412,7 +1441,10 @@ class HarmonyAdapter:
             response_json = stream_response.model_dump_json(exclude_none=True)
             responses.append(f"data: {response_json}\n\n")
 
-        return responses
+        if should_stop:
+            return responses, should_stop
+
+        return responses, False
 
     def create_stream_state(
         self,
@@ -1444,7 +1476,7 @@ class HarmonyAdapter:
         """
         if request_id in self._stream_states:
             del self._stream_states[request_id]
-            logger.debug("Cleaned up stream state for request %s", request_id)
+            logger.debug(f"Cleaned up stream state for request {request_id}")
 
     def get_stream_debug_info(self, request_id: str) -> dict[str, Any] | None:
         """Get debug information for a request's stream state."""
@@ -1497,11 +1529,11 @@ def get_harmony_adapter():
 
 
 def handle_streaming_response(tools: List[ChatCompletionToolsParam],
-                              tool_choice: str, outputs: List, model: str,
-                              request_id: str, done: bool,
-                              num_prompt_tokens: int):
+                              tool_choice: str, result: GenerationResult,
+                              model: str, request_id: str, done: bool,
+                              num_prompt_tokens: int) -> List[str]:
     first_iteration = True
-    output = outputs[0]
+    output = result.outputs[0]
 
     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
@@ -1515,15 +1547,25 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
     else:
         tools_for_parser = tools_dict
 
+    def end_streaming(res):
+        # Clean up state
+        harmony_adapter.cleanup_stream_state(request_id)
+
+        # Append usage info
+        usage_info = _create_usage_info(num_prompt_tokens, result.outputs)
+
+        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                         model=model,
+                                                         usage=usage_info)
+
+        final_usage_json = final_usage_chunk.model_dump_json(exclude_none=True)
+
+        res.append(f"data: {final_usage_json}\n\n")
+
     # Create OpenAI streaming responses
     try:
         res = []
         if done:
-            # Clean up state
-            harmony_adapter.cleanup_stream_state(request_id)
-
-            usage_info = _create_usage_info(num_prompt_tokens, outputs)
-
             # Send final message with finish_reason
             final_response = ChatCompletionStreamResponse(
                 model=model,
@@ -1538,15 +1580,10 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
             final_response_json = final_response.model_dump_json(
                 exclude_none=True)
 
-            final_usage_chunk = ChatCompletionStreamResponse(choices=[],
-                                                             model=model,
-                                                             usage=usage_info)
-            final_usage_json = final_usage_chunk.model_dump_json(
-                exclude_none=True)
             res.append(f"data: {final_response_json}\n\n")
-            res.append(f"data: {final_usage_json}\n\n")
+            end_streaming(res)
         else:
-            responses = harmony_adapter.create_openai_streaming_response(
+            responses, should_stop = harmony_adapter.create_openai_streaming_response(
                 request_id=request_id,
                 tokens=output.token_ids_diff,
                 available_tools=tools_for_parser,
@@ -1571,6 +1608,10 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
 
             res.extend(responses)
 
+            if should_stop:
+                end_streaming(res)
+                result.abort()
+
         return res
 
     except Exception as e:
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 1187f3324e..8d2977f6b4 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -743,7 +743,6 @@ class OpenAIServer:
 
             async for res in promise:
                 pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
-                # await self._extract_metrics(res)
                 for pp_res in pp_results:
                     yield pp_res
 
diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py
index a2819fad45..bced49af3c 100644
--- a/tensorrt_llm/serve/postprocess_handlers.py
+++ b/tensorrt_llm/serve/postprocess_handlers.py
@@ -436,7 +436,7 @@ def chat_harmony_streaming_post_processor(
     response = handle_streaming_response(
         tools=args.tools,
         tool_choice=args.tool_choice,
-        outputs=rsp.outputs,
+        result=rsp,
         model=args.model,
         request_id=args.request_id,
         done=rsp._done,
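
Illustrative sketch (not part of the patch): a minimal, self-contained view of the repeated-channel guard this diff introduces, assuming the `_check_channel_valid` body shown in the harmony_adapter.py hunk above. It records a channel only when it differs from the previously recorded one, and flags the stream as a repeated message once both "analysis" and "final" have appeared and a further channel switch accumulates.

    # Sketch only; mirrors the _check_channel_valid helper added in this patch.
    from typing import List


    def _check_channel_valid(generated_channels: List[str], channel: str) -> bool:
        # Record a channel transition only when the channel actually changes.
        if len(generated_channels) == 0 or generated_channels[-1] != channel:
            generated_channels.append(channel)
        # Once "analysis" and "final" have both appeared and a further switch
        # shows up, treat the output as a repeated message and stop.
        if ("analysis" in generated_channels and "final" in generated_channels
                and len(generated_channels) > 2):
            return False
        return True


    channels: List[str] = []
    assert _check_channel_valid(channels, "analysis")       # ["analysis"]
    assert _check_channel_valid(channels, "analysis")       # same channel, still valid
    assert _check_channel_valid(channels, "final")          # ["analysis", "final"]
    assert not _check_channel_valid(channels, "analysis")   # third switch -> stop

In the streaming path above, that False result becomes `{"should_stop": "Repeated message"}`, which propagates through `create_openai_streaming_response` as `should_stop=True`, so the choice is emitted with `finish_reason="stop"`, usage is appended via `end_streaming`, and the generation is aborted.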