mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5521799][fix] Trim incorrectly generated harmony messages (#7849)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
This commit is contained in:
parent
0252cee4c3
commit
6654b78c94
@@ -6,7 +6,7 @@ import re
 import time
 import traceback
 import uuid
-from typing import Any, List, Literal
+from typing import Any, List, Literal, Tuple
 
 from openai_harmony import (Author, Conversation, DeveloperContent,
                             HarmonyEncodingName, HarmonyError, Message,
@@ -14,6 +14,7 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
                             SystemContent, TextContent, ToolDescription,
                             load_harmony_encoding)
 
+from tensorrt_llm.executor import GenerationResult
 from tensorrt_llm.logger import logger
 
 # yapf: disable
@@ -29,6 +30,19 @@ from .openai_protocol import (ChatCompletionMessageParam,
 # yapf: enable
 
 
+def _check_channel_valid(generated_channels: List[str], channel: str) -> bool:
+
+    if len(generated_channels) == 0 or generated_channels[-1] != channel:
+        generated_channels.append(channel)
+
+    logger.debug(f"generated_channels: {generated_channels}")
+    if "analysis" in generated_channels and "final" in generated_channels and len(
+            generated_channels) > 2:
+        return False
+
+    return True
+
+
 class HarmonyStreamState:
     """
     Maintains harmony parsing state for a single request across multiple token batches.
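For reference, a standalone sketch of the rule the new _check_channel_valid helper enforces (the call sequences below are illustrative, not part of the change): consecutive deltas on the same channel are recorded only once, and once both an "analysis" and a "final" message have been seen, any additional channel marks the output as over-generated.

    channels = []
    assert _check_channel_valid(channels, "analysis")      # ["analysis"]
    assert _check_channel_valid(channels, "analysis")      # same channel, not re-appended
    assert _check_channel_valid(channels, "final")         # ["analysis", "final"]
    assert not _check_channel_valid(channels, "analysis")  # third channel after final -> invalid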
@@ -72,12 +86,14 @@ class HarmonyStreamState:
         # Track channel states for token preservation
         self.has_preamble_content = False
         self.current_channel_state = None  # "analysis", "commentary_preamble", "commentary_tool", "final"
+        self.generated_channels = [
+        ]  # Track generated channels to avoid generating too many messages
         self.channel_started = False  # Track if we've sent opening token for current channel
 
         # Track sent arguments for tool call streaming deltas
         self.sent_tool_arguments = {}  # tool_call_id -> sent_arguments_length
 
-        logger.debug("Created HarmonyStreamState for request %s", request_id)
+        logger.debug(f"Created HarmonyStreamState for request {request_id}")
 
     def get_parser(self) -> StreamableParser:
         return self.parser
@@ -182,6 +198,10 @@ class HarmonyStreamState:
         if not self.parser.last_content_delta:
             return None
 
+        if not _check_channel_valid(self.generated_channels,
+                                    self.parser.current_channel):
+            return {"should_stop": "Repeated message"}
+
         if self.parser.current_channel == "analysis":
             # Analysis channel -> reasoning (no token wrapping needed)
             self.current_channel_state = "analysis"
@@ -297,6 +317,8 @@ class HarmonyStreamState:
                 self.parser.last_content_delta,
             "current_channel_state":
                 self.current_channel_state,
+            "generated_channels":
+                self.generated_channels,
             "channel_started":
                 self.channel_started,
             "has_preamble_content":
@@ -1014,12 +1036,16 @@ class HarmonyAdapter:
         commentary_preambles = []
         tool_calls = []
         final_content = ""
+        generated_channels = []
 
         for msg in harmony_messages:
             msg_channel = getattr(msg, 'channel', None)
             msg_recipient = getattr(msg, 'recipient', None)
             msg_content = getattr(msg, 'content', [])
 
+            if not _check_channel_valid(generated_channels, msg_channel):
+                continue
+
             if msg_channel == "analysis":
                 for content in msg_content:
                     if isinstance(content, TextContent):
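In this non-streaming aggregation path, an invalid message is skipped rather than terminating the response. A rough sketch of the net effect (the message list and .channel attribute access here are illustrative; the real loop uses getattr and also extracts content):

    # Suppose the model emitted channels: analysis, final, analysis, final
    generated_channels = []
    kept = [m for m in harmony_messages
            if _check_channel_valid(generated_channels, m.channel)]
    # kept retains only the first analysis/final pair; the trailing
    # repeated messages are trimmed from the aggregated response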
@@ -1299,7 +1325,7 @@ class HarmonyAdapter:
             tokens: list[int],
             available_tools: list[dict[str, Any]] | None = None,
             model_name: str = "harmony-model",
-            tool_choice: str | None = None) -> list[str]:
+            tool_choice: str | None = None) -> Tuple[list[str], bool]:
         """
         Create properly formatted OpenAI streaming responses from harmony tokens.
 
@@ -1397,12 +1423,15 @@ class HarmonyAdapter:
                     delta_message.reasoning_content = None
                     # tool_calls will use default factory (empty list)
 
+            should_stop = ("should_stop" in harmony_delta)
+
             # Create the streaming response
-            choice = ChatCompletionResponseStreamChoice(index=0,
-                                                        delta=delta_message,
-                                                        logprobs=None,
-                                                        finish_reason=None,
-                                                        stop_reason=None)
+            choice = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=delta_message,
+                logprobs=None,
+                finish_reason="stop" if should_stop else None,
+                stop_reason=None)
 
             stream_response = ChatCompletionStreamResponse(model=model_name,
                                                            choices=[choice],
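With this change, the chunk carrying the delta that tripped the guard is emitted with finish_reason "stop" instead of None. On the wire that chunk would look roughly like the following (the payload shape is an assumption based on the OpenAI chat.completion.chunk format these response models mirror, with exclude_none dropping unset fields):

    data: {"object":"chat.completion.chunk","model":"harmony-model","choices":[{"index":0,"delta":{"content":"..."},"finish_reason":"stop"}]}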
@@ -1412,7 +1441,10 @@ class HarmonyAdapter:
             response_json = stream_response.model_dump_json(exclude_none=True)
             responses.append(f"data: {response_json}\n\n")
 
-        return responses
+            if should_stop:
+                return responses, should_stop
+
+        return responses, False
 
     def create_stream_state(
         self,
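Since the method now returns a (responses, should_stop) tuple, callers have to unpack both values. A minimal consumption sketch (adapter, new_tokens, and send_to_client are hypothetical stand-ins; the real caller is handle_streaming_response below):

    responses, should_stop = adapter.create_openai_streaming_response(
        request_id=request_id, tokens=new_tokens)
    for sse_line in responses:
        send_to_client(sse_line)  # hypothetical transport
    if should_stop:
        ...  # finalize the stream and abort generation (see below)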
@@ -1444,7 +1476,7 @@ class HarmonyAdapter:
         """
         if request_id in self._stream_states:
             del self._stream_states[request_id]
-            logger.debug("Cleaned up stream state for request %s", request_id)
+            logger.debug(f"Cleaned up stream state for request {request_id}")
 
     def get_stream_debug_info(self, request_id: str) -> dict[str, Any] | None:
         """Get debug information for a request's stream state."""
@@ -1497,11 +1529,11 @@ def get_harmony_adapter():
 
 
 def handle_streaming_response(tools: List[ChatCompletionToolsParam],
-                              tool_choice: str, outputs: List, model: str,
-                              request_id: str, done: bool,
-                              num_prompt_tokens: int):
+                              tool_choice: str, result: GenerationResult,
+                              model: str, request_id: str, done: bool,
+                              num_prompt_tokens: int) -> List[str]:
     first_iteration = True
-    output = outputs[0]
+    output = result.outputs[0]
 
     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
@@ -1515,15 +1547,25 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
     else:
         tools_for_parser = tools_dict
 
+    def end_streaming(res):
+        # Clean up state
+        harmony_adapter.cleanup_stream_state(request_id)
+
+        # Append usage info
+        usage_info = _create_usage_info(num_prompt_tokens, result.outputs)
+
+        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                         model=model,
+                                                         usage=usage_info)
+
+        final_usage_json = final_usage_chunk.model_dump_json(exclude_none=True)
+
+        res.append(f"data: {final_usage_json}\n\n")
+
     # Create OpenAI streaming responses
     try:
         res = []
         if done:
-            # Clean up state
-            harmony_adapter.cleanup_stream_state(request_id)
-
-            usage_info = _create_usage_info(num_prompt_tokens, outputs)
-
             # Send final message with finish_reason
             final_response = ChatCompletionStreamResponse(
                 model=model,
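The new end_streaming helper centralizes what previously happened only on the done path: cleaning up per-request parser state and appending the trailing usage chunk. That final chunk has an empty choices list and, serialized with exclude_none, would look roughly like this (token counts illustrative):

    data: {"object":"chat.completion.chunk","model":"harmony-model","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":34,"total_tokens":46}}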
@@ -1538,15 +1580,10 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
 
             final_response_json = final_response.model_dump_json(
                 exclude_none=True)
-            final_usage_chunk = ChatCompletionStreamResponse(choices=[],
-                                                             model=model,
-                                                             usage=usage_info)
-            final_usage_json = final_usage_chunk.model_dump_json(
-                exclude_none=True)
             res.append(f"data: {final_response_json}\n\n")
-            res.append(f"data: {final_usage_json}\n\n")
+            end_streaming(res)
         else:
-            responses = harmony_adapter.create_openai_streaming_response(
+            responses, should_stop = harmony_adapter.create_openai_streaming_response(
                 request_id=request_id,
                 tokens=output.token_ids_diff,
                 available_tools=tools_for_parser,
@@ -1571,6 +1608,10 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
 
         res.extend(responses)
 
+        if should_stop:
+            end_streaming(res)
+            result.abort()
+
         return res
 
     except Exception as e:
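This early-stop branch also explains the signature change above: handle_streaming_response now takes the full GenerationResult rather than just its outputs, both so end_streaming can read result.outputs and so result.abort() can cancel the in-flight request once a repeated message is detected. Condensed, the flow on detection is (a restatement of the diff, not new behavior):

    res.extend(responses)   # last chunk already carries finish_reason="stop"
    if should_stop:
        end_streaming(res)  # usage chunk + per-request state cleanup
        result.abort()      # stop generating tokens for the trimmed response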
@@ -743,7 +743,6 @@ class OpenAIServer:
 
         async for res in promise:
             pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
-            # await self._extract_metrics(res)
             for pp_res in pp_results:
                 yield pp_res
 
@@ -436,7 +436,7 @@ def chat_harmony_streaming_post_processor(
     response = handle_streaming_response(
         tools=args.tools,
         tool_choice=args.tool_choice,
-        outputs=rsp.outputs,
+        result=rsp,
         model=args.model,
         request_id=args.request_id,
         done=rsp._done,