[https://nvbugs/5521799][fix] Trim incorrectly generated harmony messages (#7849)

Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
JunyiXu-nv 2025-09-24 16:38:43 +08:00 committed by GitHub
parent 0252cee4c3
commit 6654b78c94
3 changed files with 68 additions and 28 deletions


@@ -6,7 +6,7 @@ import re
 import time
 import traceback
 import uuid
-from typing import Any, List, Literal
+from typing import Any, List, Literal, Tuple

 from openai_harmony import (Author, Conversation, DeveloperContent,
                             HarmonyEncodingName, HarmonyError, Message,
@@ -14,6 +14,7 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
                             SystemContent, TextContent, ToolDescription,
                             load_harmony_encoding)

+from tensorrt_llm.executor import GenerationResult
 from tensorrt_llm.logger import logger

 # yapf: disable
@@ -29,6 +30,19 @@ from .openai_protocol import (ChatCompletionMessageParam,
 # yapf: enable


+def _check_channel_valid(generated_channels: List[str], channel: str) -> bool:
+    if len(generated_channels) == 0 or generated_channels[-1] != channel:
+        generated_channels.append(channel)
+    logger.debug(f"generated_channels: {generated_channels}")
+
+    if "analysis" in generated_channels and "final" in generated_channels and len(
+            generated_channels) > 2:
+        return False
+    return True
+
+
 class HarmonyStreamState:
     """
     Maintains harmony parsing state for a single request across multiple token batches.
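
For context on the trimming rule itself: _check_channel_valid collapses consecutive deltas on the same channel into one entry, then rejects any further message once both "analysis" and "final" have been seen alongside a third channel transition. A minimal standalone sketch of the same logic (the logger call is dropped; the asserts are illustrative, not part of the commit):

from typing import List

def _check_channel_valid(generated_channels: List[str], channel: str) -> bool:
    # Record a transition only when the channel differs from the last one,
    # so consecutive deltas on the same channel count as a single message.
    if len(generated_channels) == 0 or generated_channels[-1] != channel:
        generated_channels.append(channel)
    # Once both "analysis" and "final" have appeared and a third transition
    # exists, the model has started a spurious extra message: reject it.
    if ("analysis" in generated_channels and "final" in generated_channels
            and len(generated_channels) > 2):
        return False
    return True

channels: List[str] = []
assert _check_channel_valid(channels, "analysis")      # first message
assert _check_channel_valid(channels, "analysis")      # same channel, no new entry
assert _check_channel_valid(channels, "final")         # analysis -> final is fine
assert not _check_channel_valid(channels, "analysis")  # extra message is trimmed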
@@ -72,12 +86,14 @@ class HarmonyStreamState:
         # Track channel states for token preservation
         self.has_preamble_content = False
         self.current_channel_state = None  # "analysis", "commentary_preamble", "commentary_tool", "final"
+        self.generated_channels = [
+        ]  # Track generated channels to avoid generating too many messages
         self.channel_started = False  # Track if we've sent opening token for current channel

         # Track sent arguments for tool call streaming deltas
         self.sent_tool_arguments = {}  # tool_call_id -> sent_arguments_length

-        logger.debug("Created HarmonyStreamState for request %s", request_id)
+        logger.debug(f"Created HarmonyStreamState for request {request_id}")

     def get_parser(self) -> StreamableParser:
         return self.parser
@@ -182,6 +198,10 @@ class HarmonyStreamState:
         if not self.parser.last_content_delta:
             return None

+        if not _check_channel_valid(self.generated_channels,
+                                    self.parser.current_channel):
+            return {"should_stop": "Repeated message"}
+
         if self.parser.current_channel == "analysis":
             # Analysis channel -> reasoning (no token wrapping needed)
             self.current_channel_state = "analysis"
@@ -297,6 +317,8 @@ class HarmonyStreamState:
             self.parser.last_content_delta,
             "current_channel_state":
             self.current_channel_state,
+            "generated_channels":
+            self.generated_channels,
             "channel_started":
             self.channel_started,
             "has_preamble_content":
@@ -1014,12 +1036,16 @@ class HarmonyAdapter:
         commentary_preambles = []
         tool_calls = []
         final_content = ""
+        generated_channels = []

         for msg in harmony_messages:
             msg_channel = getattr(msg, 'channel', None)
             msg_recipient = getattr(msg, 'recipient', None)
             msg_content = getattr(msg, 'content', [])

+            if not _check_channel_valid(generated_channels, msg_channel):
+                continue
+
             if msg_channel == "analysis":
                 for content in msg_content:
                     if isinstance(content, TextContent):
@@ -1299,7 +1325,7 @@ class HarmonyAdapter:
             tokens: list[int],
             available_tools: list[dict[str, Any]] | None = None,
             model_name: str = "harmony-model",
-            tool_choice: str | None = None) -> list[str]:
+            tool_choice: str | None = None) -> Tuple[list[str], bool]:
         """
         Create properly formatted OpenAI streaming responses from harmony tokens.
@@ -1397,12 +1423,15 @@ class HarmonyAdapter:
                 delta_message.reasoning_content = None
                 # tool_calls will use default factory (empty list)

+            should_stop = ("should_stop" in harmony_delta)
+
             # Create the streaming response
-            choice = ChatCompletionResponseStreamChoice(index=0,
-                                                        delta=delta_message,
-                                                        logprobs=None,
-                                                        finish_reason=None,
-                                                        stop_reason=None)
+            choice = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=delta_message,
+                logprobs=None,
+                finish_reason="stop" if should_stop else None,
+                stop_reason=None)

             stream_response = ChatCompletionStreamResponse(model=model_name,
                                                            choices=[choice],
@@ -1412,7 +1441,10 @@ class HarmonyAdapter:
             response_json = stream_response.model_dump_json(exclude_none=True)
             responses.append(f"data: {response_json}\n\n")

-        return responses
+            if should_stop:
+                return responses, should_stop
+
+        return responses, False

     def create_stream_state(
             self,
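
Since create_openai_streaming_response now returns a (responses, should_stop) pair and exits the token loop at the first sentinel delta, any deltas generated after the stop point are silently dropped. A self-contained sketch of that contract, using a hypothetical build_stream stand-in rather than the real adapter method:

import json

def build_stream(deltas: list[dict]) -> tuple[list[str], bool]:
    # Hypothetical stand-in: frame each delta as an SSE chunk, and stop
    # at the first delta carrying the "should_stop" sentinel key.
    responses: list[str] = []
    for delta in deltas:
        should_stop = "should_stop" in delta
        chunk = {
            "choices": [{
                "index": 0,
                "delta": {} if should_stop else delta,
                "finish_reason": "stop" if should_stop else None,
            }]
        }
        responses.append(f"data: {json.dumps(chunk)}\n\n")
        if should_stop:
            return responses, True  # trim: later deltas are never emitted
    return responses, False

chunks, stopped = build_stream([{"content": "Hi"},
                                {"should_stop": "Repeated message"},
                                {"content": "never sent"}])
assert stopped and len(chunks) == 2  # the trailing delta was trimmed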
@@ -1444,7 +1476,7 @@ class HarmonyAdapter:
         """
         if request_id in self._stream_states:
             del self._stream_states[request_id]
-            logger.debug("Cleaned up stream state for request %s", request_id)
+            logger.debug(f"Cleaned up stream state for request {request_id}")

     def get_stream_debug_info(self, request_id: str) -> dict[str, Any] | None:
         """Get debug information for a request's stream state."""
@@ -1497,11 +1529,11 @@ def get_harmony_adapter():
 def handle_streaming_response(tools: List[ChatCompletionToolsParam],
-                              tool_choice: str, outputs: List, model: str,
-                              request_id: str, done: bool,
-                              num_prompt_tokens: int):
+                              tool_choice: str, result: GenerationResult,
+                              model: str, request_id: str, done: bool,
+                              num_prompt_tokens: int) -> List[str]:
     first_iteration = True
-    output = outputs[0]
+    output = result.outputs[0]

     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
@@ -1515,15 +1547,25 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
     else:
         tools_for_parser = tools_dict

+    def end_streaming(res):
+        # Clean up state
+        harmony_adapter.cleanup_stream_state(request_id)
+
+        # Append usage info
+        usage_info = _create_usage_info(num_prompt_tokens, result.outputs)
+        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                         model=model,
+                                                         usage=usage_info)
+        final_usage_json = final_usage_chunk.model_dump_json(exclude_none=True)
+        res.append(f"data: {final_usage_json}\n\n")
+
     # Create OpenAI streaming responses
     try:
         res = []
         if done:
-            # Clean up state
-            harmony_adapter.cleanup_stream_state(request_id)
-            usage_info = _create_usage_info(num_prompt_tokens, outputs)
-
             # Send final message with finish_reason
             final_response = ChatCompletionStreamResponse(
                 model=model,
@@ -1538,15 +1580,10 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
             final_response_json = final_response.model_dump_json(
                 exclude_none=True)
-            final_usage_chunk = ChatCompletionStreamResponse(choices=[],
-                                                             model=model,
-                                                             usage=usage_info)
-            final_usage_json = final_usage_chunk.model_dump_json(
-                exclude_none=True)

             res.append(f"data: {final_response_json}\n\n")
-            res.append(f"data: {final_usage_json}\n\n")
+            end_streaming(res)
         else:
-            responses = harmony_adapter.create_openai_streaming_response(
+            responses, should_stop = harmony_adapter.create_openai_streaming_response(
                 request_id=request_id,
                 tokens=output.token_ids_diff,
                 available_tools=tools_for_parser,
@@ -1571,6 +1608,10 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
             res.extend(responses)

+            if should_stop:
+                end_streaming(res)
+                result.abort()
+
         return res
     except Exception as e:
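
On the caller side, an early stop finishes the stream the same way a normal completion does: end_streaming emits the usage chunk and clears per-request state, after which result.abort() asks the executor to stop generating tokens for the request. A hedged sketch of that ordering; FakeResult and this trimmed end_streaming are hypothetical stand-ins (the real abort lives on tensorrt_llm's GenerationResult):

from typing import List

class FakeResult:
    """Hypothetical stand-in for tensorrt_llm's GenerationResult."""

    def __init__(self) -> None:
        self.aborted = False

    def abort(self) -> None:
        # The real abort() tells the executor to stop producing tokens.
        self.aborted = True

def end_streaming(res: List[str]) -> None:
    # Stand-in for the helper above: append the final usage chunk.
    res.append('data: {"choices": [], "usage": {}}\n\n')

res: List[str] = []
result = FakeResult()
end_streaming(res)  # usage is emitted exactly once
result.abort()      # then generation stops for this request
assert result.aborted and len(res) == 1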


@@ -743,7 +743,6 @@ class OpenAIServer:
             async for res in promise:
                 pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
-                # await self._extract_metrics(res)
                 for pp_res in pp_results:
                     yield pp_res


@@ -436,7 +436,7 @@ def chat_harmony_streaming_post_processor(
     response = handle_streaming_response(
         tools=args.tools,
         tool_choice=args.tool_choice,
-        outputs=rsp.outputs,
+        result=rsp,
         model=args.model,
         request_id=args.request_id,
         done=rsp._done,